diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index ff0965c56bb..26157d7d23d 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -624,6 +624,8 @@ def __post_init__(self):
             raise NotImplementedError(
                 f"not support model_impl: '{self.model_impl}'. "
                 f"Must be one of: {', '.join(valid_model_impls)}"
             )
+        if envs.FD_ENABLE_RL == 1:
+            self.moe_gate_fp32 = True
 
         self.post_init_all_ports()
diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py
index 0c7ac3e22b1..5534871fcb3 100644
--- a/fastdeploy/envs.py
+++ b/fastdeploy/envs.py
@@ -266,6 +266,8 @@ def _validate_split_kv_size(value: int) -> int:
     "FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST": lambda: bool(
         int(os.getenv("FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST", "1"))
     ),
+    # Whether to align RoPE and moe gate precision with training
+    "FD_ENABLE_RL": lambda: int(os.getenv("FD_ENABLE_RL", "0")),
 }
 
 
diff --git a/fastdeploy/model_executor/models/glm4_moe.py b/fastdeploy/model_executor/models/glm4_moe.py
index 3f45e9df614..64b82229d37 100644
--- a/fastdeploy/model_executor/models/glm4_moe.py
+++ b/fastdeploy/model_executor/models/glm4_moe.py
@@ -150,9 +150,7 @@ def __init__(
             output_size=fd_config.model_config.n_routed_experts,
             with_bias=False,
             skip_quant=True,
-            weight_dtype=(
-                "float32" if fd_config.load_config.dynamic_load_weight or fd_config.model_config.moe_gate_fp32 else ""
-            ),
+            weight_dtype=("float32" if fd_config.model_config.moe_gate_fp32 else ""),
         )
         self.gate.e_score_correction_bias = self.create_parameter(
             shape=[1, fd_config.model_config.n_routed_experts],
diff --git a/fastdeploy/model_executor/models/qwen3moe.py b/fastdeploy/model_executor/models/qwen3moe.py
index 74ca37ab695..6c443d68bcc 100644
--- a/fastdeploy/model_executor/models/qwen3moe.py
+++ b/fastdeploy/model_executor/models/qwen3moe.py
@@ -77,9 +77,7 @@ def __init__(
             output_size=fd_config.model_config.num_experts,
             with_bias=False,
             skip_quant=True,
-            weight_dtype=(
-                "float32" if fd_config.load_config.dynamic_load_weight or fd_config.model_config.moe_gate_fp32 else ""
-            ),
+            weight_dtype=("float32" if fd_config.model_config.moe_gate_fp32 else ""),
         )
 
     def forward(self, x, forward_meta):
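
A minimal, self-contained sketch of the behaviour this patch introduces; the class and helper names below are hypothetical stand-ins, not FastDeploy's actual EngineArgs or model classes. Setting FD_ENABLE_RL=1 in the environment forces moe_gate_fp32 on, and the MoE gate layers then request "float32" weights instead of the model's default dtype:

import os
from dataclasses import dataclass

# Mirrors the new envs.py entry: FD_ENABLE_RL defaults to "0" (off).
FD_ENABLE_RL = int(os.getenv("FD_ENABLE_RL", "0"))

@dataclass
class EngineArgsSketch:  # hypothetical stand-in for EngineArgs
    moe_gate_fp32: bool = False

    def __post_init__(self):
        # Same override as the args_utils.py hunk: RL alignment forces the fp32 gate.
        if FD_ENABLE_RL == 1:
            self.moe_gate_fp32 = True

def gate_weight_dtype(moe_gate_fp32: bool) -> str:
    # Same expression the model files now use; "" means the model's default dtype.
    return "float32" if moe_gate_fp32 else ""

if __name__ == "__main__":
    args = EngineArgsSketch()
    print("moe_gate_fp32:", args.moe_gate_fp32)
    print("gate weight_dtype:", repr(gate_weight_dtype(args.moe_gate_fp32)))

Running the sketch with FD_ENABLE_RL=1 prints "float32" as the gate weight dtype; without it, the gate keeps the default dtype. Note that after this change dynamic_load_weight alone no longer forces the gate to fp32; only moe_gate_fp32 (set explicitly or via FD_ENABLE_RL) does.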