From aa7c88ec9416d3f03cd2de25bd6bf89c0972441d Mon Sep 17 00:00:00 2001 From: Jiang020609 <190608333+Jiang020609@users.noreply.github.com> Date: Mon, 1 Jun 2026 23:41:56 +0800 Subject: [PATCH] fix(gpt-oss): fuse raw expert tensors for sglang --- .../megatron_utils/megatron_to_hf/gpt_oss.py | 75 ++++++++--- tests/test_gpt_oss_raw_converter.py | 124 ++++++++++++++++++ 2 files changed, 183 insertions(+), 16 deletions(-) create mode 100644 tests/test_gpt_oss_raw_converter.py diff --git a/slime/backends/megatron_utils/megatron_to_hf/gpt_oss.py b/slime/backends/megatron_utils/megatron_to_hf/gpt_oss.py index b90507f04b..b4510b7277 100644 --- a/slime/backends/megatron_utils/megatron_to_hf/gpt_oss.py +++ b/slime/backends/megatron_utils/megatron_to_hf/gpt_oss.py @@ -2,6 +2,40 @@ import torch +_expert_tensor_cache = {} + + +def _interleave_gate_up(param): + gate, up = param.chunk(2, dim=0) + return torch.stack([gate, up], dim=1).reshape(-1, *param.shape[1:]).contiguous() + + +def _collect_fused_expert_tensor(args, hf_name, expert_idx, param): + num_experts = getattr(args, "num_experts", None) + if num_experts is None: + raise ValueError("args.num_experts is required to fuse GPT-OSS expert tensors") + + num_experts = int(num_experts) + expert_idx = int(expert_idx) + if expert_idx < 0 or expert_idx >= num_experts: + raise ValueError(f"GPT-OSS expert index {expert_idx} is out of range for {num_experts} experts") + + bucket = _expert_tensor_cache.setdefault(hf_name, {}) + if expert_idx in bucket: + raise ValueError(f"Duplicate GPT-OSS expert tensor for {hf_name} expert {expert_idx}") + bucket[expert_idx] = param + + if len(bucket) != num_experts: + return [] + + missing = [i for i in range(num_experts) if i not in bucket] + if missing: + raise ValueError(f"Missing GPT-OSS expert tensors for {hf_name}: {missing}") + + fused = torch.stack([bucket[i] for i in range(num_experts)], dim=0).contiguous() + del _expert_tensor_cache[hf_name] + return [(hf_name, fused)] + def convert_gpt_oss_to_hf(args, name, param): """Convert Megatron GPT-OSS parameter names to HF format for weight update to SGLang.""" @@ -27,15 +61,20 @@ def convert_gpt_oss_to_hf(args, name, param): if match: rest, expert_idx = match.groups() if rest == "linear_fc1": - gate_weight, up_weight = param.chunk(2, dim=0) - return [ - (f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj.weight", gate_weight), - (f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj.weight", up_weight), - ] + param = _interleave_gate_up(param).transpose(0, 1).contiguous() + return _collect_fused_expert_tensor( + args, + f"model.layers.{layer_idx}.mlp.experts.gate_up_proj", + expert_idx, + param, + ) elif rest == "linear_fc2": - return [ - (f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.weight", param), - ] + return _collect_fused_expert_tensor( + args, + f"model.layers.{layer_idx}.mlp.experts.down_proj", + expert_idx, + param.transpose(0, 1).contiguous(), + ) else: raise ValueError(f"Unknown expert parameter name: {name}") @@ -45,15 +84,19 @@ def convert_gpt_oss_to_hf(args, name, param): if match: rest, expert_idx = match.groups() if rest == "linear_fc1": - gate_bias, up_bias = param.chunk(2, dim=0) - return [ - (f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj.bias", gate_bias), - (f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj.bias", up_bias), - ] + return _collect_fused_expert_tensor( + args, + f"model.layers.{layer_idx}.mlp.experts.gate_up_proj_bias", + expert_idx, + _interleave_gate_up(param), + ) elif rest == "linear_fc2": - return [ - (f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.bias", param), - ] + return _collect_fused_expert_tensor( + args, + f"model.layers.{layer_idx}.mlp.experts.down_proj_bias", + expert_idx, + param.contiguous(), + ) else: raise ValueError(f"Unknown expert bias parameter name: {name}") diff --git a/tests/test_gpt_oss_raw_converter.py b/tests/test_gpt_oss_raw_converter.py new file mode 100644 index 0000000000..de92952ea7 --- /dev/null +++ b/tests/test_gpt_oss_raw_converter.py @@ -0,0 +1,124 @@ +import importlib.util +import sys +import types +from pathlib import Path + +import pytest +import torch + + +NUM_GPUS = 0 + + +def load_raw_export_module(): + module_path = ( + Path(__file__).resolve().parents[1] / "slime" / "backends" / "megatron_utils" / "megatron_to_hf" / "gpt_oss.py" + ) + module_name = "test_gpt_oss_raw_export_module" + sys.modules.pop(module_name, None) + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +def convert_args(): + return types.SimpleNamespace( + hidden_size=6, + kv_channels=None, + num_attention_heads=2, + num_query_groups=1, + num_experts=2, + ) + + +def convert_expert(module, args, name, param, expert_idx): + return module.convert_gpt_oss_to_hf( + args, + f"module.module.decoder.layers.3.mlp.experts.{name}{expert_idx}", + param, + ) + + +def interleave_gate_up(param): + gate, up = param.chunk(2, dim=0) + return torch.stack([gate, up], dim=1).reshape(-1, *param.shape[1:]).contiguous() + + +@pytest.mark.unit +def test_gpt_oss_raw_fc1_weight_fuses_experts_for_sglang_gate_up_proj(): + module = load_raw_export_module() + args = convert_args() + expert0 = torch.arange(24, dtype=torch.float32).view(4, 6) + expert1 = expert0 + 100 + + assert convert_expert(module, args, "linear_fc1.weight", expert1, 1) == [] + converted = convert_expert(module, args, "linear_fc1.weight", expert0, 0) + + expected = torch.stack( + [ + interleave_gate_up(expert0).transpose(0, 1).contiguous(), + interleave_gate_up(expert1).transpose(0, 1).contiguous(), + ], + dim=0, + ) + assert len(converted) == 1 + hf_name, hf_weight = converted[0] + assert hf_name == "model.layers.3.mlp.experts.gate_up_proj" + assert torch.equal(hf_weight, expected) + + +@pytest.mark.unit +def test_gpt_oss_raw_fc2_weight_transposes_and_fuses_experts_for_sglang_down_proj(): + module = load_raw_export_module() + args = convert_args() + expert0 = torch.arange(18, dtype=torch.float32).view(6, 3) + expert1 = expert0 + 100 + + assert convert_expert(module, args, "linear_fc2.weight", expert0, 0) == [] + converted = convert_expert(module, args, "linear_fc2.weight", expert1, 1) + + expected = torch.stack([expert0.transpose(0, 1).contiguous(), expert1.transpose(0, 1).contiguous()], dim=0) + assert len(converted) == 1 + hf_name, hf_weight = converted[0] + assert hf_name == "model.layers.3.mlp.experts.down_proj" + assert torch.equal(hf_weight, expected) + + +@pytest.mark.unit +def test_gpt_oss_raw_fc1_bias_fuses_interleaved_gate_up_biases(): + module = load_raw_export_module() + args = convert_args() + expert0 = torch.arange(4, dtype=torch.float32) + expert1 = expert0 + 100 + + assert convert_expert(module, args, "linear_fc1.bias", expert0, 0) == [] + converted = convert_expert(module, args, "linear_fc1.bias", expert1, 1) + + expected = torch.stack([interleave_gate_up(expert0), interleave_gate_up(expert1)], dim=0) + assert len(converted) == 1 + hf_name, hf_weight = converted[0] + assert hf_name == "model.layers.3.mlp.experts.gate_up_proj_bias" + assert torch.equal(hf_weight, expected) + + +@pytest.mark.unit +def test_gpt_oss_raw_fc2_bias_fuses_down_proj_biases(): + module = load_raw_export_module() + args = convert_args() + expert0 = torch.arange(3, dtype=torch.float32) + expert1 = expert0 + 100 + + assert convert_expert(module, args, "linear_fc2.bias", expert1, 1) == [] + converted = convert_expert(module, args, "linear_fc2.bias", expert0, 0) + + expected = torch.stack([expert0, expert1], dim=0) + assert len(converted) == 1 + hf_name, hf_weight = converted[0] + assert hf_name == "model.layers.3.mlp.experts.down_proj_bias" + assert torch.equal(hf_weight, expected) + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__]))