Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 59 additions & 16 deletions slime/backends/megatron_utils/megatron_to_hf/gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,40 @@

import torch

# Staging area for per-expert tensors that are waiting to be fused: maps an HF
# parameter name to a {expert_idx: tensor} dict. Entries are created and removed
# by _collect_fused_expert_tensor.
_expert_tensor_cache = {}


def _interleave_gate_up(param):
    """Interleave the two dim-0 halves of ``param`` row by row.

    ``param`` is chunked in two along dim 0 (the first chunk is treated as
    "gate", the second as "up"). The result has the same trailing dims and its
    rows alternate gate[0], up[0], gate[1], up[1], ...

    Returns:
        A contiguous tensor holding the interleaved rows.
    """
    gate_half, up_half = param.chunk(2, dim=0)
    trailing_dims = param.shape[1:]
    # Stacking on dim 1 pairs row i of both halves; collapsing the first two
    # dims afterwards yields the alternating row order.
    paired = torch.stack((gate_half, up_half), dim=1)
    return paired.reshape(-1, *trailing_dims).contiguous()


def _collect_fused_expert_tensor(args, hf_name, expert_idx, param):
    """Buffer one expert's tensor and emit the fused tensor once all have arrived.

    Tensors are staged in the module-level ``_expert_tensor_cache`` under
    ``hf_name``. While fewer than ``args.num_experts`` tensors are staged the
    function returns an empty list. When the last one arrives, the staged
    tensors are stacked along a new leading dim in expert-index order.

    Args:
        args: Object exposing ``num_experts``.
        hf_name: HF parameter name the fused tensor is emitted under; also the
            staging key.
        expert_idx: Index of the expert ``param`` belongs to (anything accepted
            by ``int()``).
        param: The tensor for that expert.

    Returns:
        ``[]`` while experts are still pending, otherwise
        ``[(hf_name, fused_tensor)]``.

    Raises:
        ValueError: If ``args.num_experts`` is missing, ``expert_idx`` is out
            of range, the expert was already staged for ``hf_name``, or the
            staged set does not cover every expert index.
    """
    num_experts = getattr(args, "num_experts", None)
    if num_experts is None:
        raise ValueError("args.num_experts is required to fuse GPT-OSS expert tensors")

    num_experts = int(num_experts)
    expert_idx = int(expert_idx)
    if not 0 <= expert_idx < num_experts:
        raise ValueError(f"GPT-OSS expert index {expert_idx} is out of range for {num_experts} experts")

    bucket = _expert_tensor_cache.setdefault(hf_name, {})
    if expert_idx in bucket:
        raise ValueError(f"Duplicate GPT-OSS expert tensor for {hf_name} expert {expert_idx}")
    bucket[expert_idx] = param

    if len(bucket) != num_experts:
        return []

    # Detach the completed bucket from the cache *before* validating/stacking.
    # If either step raises (e.g. torch.stack on mismatched shapes), no stale
    # entry is left behind to make every later call fail with "Duplicate".
    del _expert_tensor_cache[hf_name]

    missing = [i for i in range(num_experts) if i not in bucket]
    if missing:
        raise ValueError(f"Missing GPT-OSS expert tensors for {hf_name}: {missing}")

    fused = torch.stack([bucket[i] for i in range(num_experts)], dim=0).contiguous()
    return [(hf_name, fused)]


def convert_gpt_oss_to_hf(args, name, param):
"""Convert Megatron GPT-OSS parameter names to HF format for weight update to SGLang."""
Expand All @@ -27,15 +61,20 @@ def convert_gpt_oss_to_hf(args, name, param):
if match:
rest, expert_idx = match.groups()
if rest == "linear_fc1":
gate_weight, up_weight = param.chunk(2, dim=0)
return [
(f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj.weight", gate_weight),
(f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj.weight", up_weight),
]
param = _interleave_gate_up(param).transpose(0, 1).contiguous()
return _collect_fused_expert_tensor(
args,
f"model.layers.{layer_idx}.mlp.experts.gate_up_proj",
expert_idx,
param,
)
elif rest == "linear_fc2":
return [
(f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.weight", param),
]
return _collect_fused_expert_tensor(
args,
f"model.layers.{layer_idx}.mlp.experts.down_proj",
expert_idx,
param.transpose(0, 1).contiguous(),
)
else:
raise ValueError(f"Unknown expert parameter name: {name}")

Expand All @@ -45,15 +84,19 @@ def convert_gpt_oss_to_hf(args, name, param):
if match:
rest, expert_idx = match.groups()
if rest == "linear_fc1":
gate_bias, up_bias = param.chunk(2, dim=0)
return [
(f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj.bias", gate_bias),
(f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj.bias", up_bias),
]
return _collect_fused_expert_tensor(
args,
f"model.layers.{layer_idx}.mlp.experts.gate_up_proj_bias",
expert_idx,
_interleave_gate_up(param),
)
elif rest == "linear_fc2":
return [
(f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.bias", param),
]
return _collect_fused_expert_tensor(
args,
f"model.layers.{layer_idx}.mlp.experts.down_proj_bias",
expert_idx,
param.contiguous(),
)
else:
raise ValueError(f"Unknown expert bias parameter name: {name}")

Expand Down
124 changes: 124 additions & 0 deletions tests/test_gpt_oss_raw_converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import importlib.util
import sys
import types
from pathlib import Path

import pytest
import torch


# NOTE(review): not referenced anywhere else in this file; presumably read by
# the test harness to indicate this file needs no GPUs — confirm.
NUM_GPUS = 0


def load_raw_export_module():
    """Load ``gpt_oss.py`` directly from its file path and return the module.

    The file is located relative to this test file (one directory up, then
    ``slime/backends/megatron_utils/megatron_to_hf/gpt_oss.py``) and executed
    into a new module object on every call, so module-level state in the
    loaded file starts fresh each time.
    """
    repo_root = Path(__file__).resolve().parents[1]
    source_file = repo_root.joinpath("slime", "backends", "megatron_utils", "megatron_to_hf", "gpt_oss.py")
    alias = "test_gpt_oss_raw_export_module"
    # Drop any previously registered module under the same alias.
    sys.modules.pop(alias, None)
    module_spec = importlib.util.spec_from_file_location(alias, source_file)
    loaded = importlib.util.module_from_spec(module_spec)
    assert module_spec.loader is not None
    module_spec.loader.exec_module(loaded)
    return loaded


def convert_args():
    """Return a fresh minimal args namespace for the converter tests.

    The namespace carries ``hidden_size``, ``kv_channels``,
    ``num_attention_heads``, ``num_query_groups`` and ``num_experts``.
    """
    fields = {
        "hidden_size": 6,
        "kv_channels": None,
        "num_attention_heads": 2,
        "num_query_groups": 1,
        "num_experts": 2,
    }
    return types.SimpleNamespace(**fields)


def convert_expert(module, args, name, param, expert_idx):
    """Run ``module.convert_gpt_oss_to_hf`` for one layer-3 expert parameter.

    ``name`` and ``expert_idx`` are concatenated (no separator) onto the
    fixed ``module.module.decoder.layers.3.mlp.experts.`` prefix to form the
    parameter name passed to the converter. The converter's return value is
    passed through unchanged.
    """
    megatron_name = f"module.module.decoder.layers.3.mlp.experts.{name}{expert_idx}"
    convert = module.convert_gpt_oss_to_hf
    return convert(args, megatron_name, param)


def interleave_gate_up(param):
    """Reference interleave used to build expected values in the tests.

    Splits ``param`` in two along dim 0 and returns a contiguous tensor whose
    rows alternate between the first half and the second half.
    """
    first_half, second_half = param.chunk(2, dim=0)
    tail_shape = param.shape[1:]
    # Pair row i of both halves on a new dim 1, then collapse dims 0 and 1.
    stacked = torch.stack((first_half, second_half), dim=1)
    return stacked.reshape(-1, *tail_shape).contiguous()


@pytest.mark.unit
def test_gpt_oss_raw_fc1_weight_fuses_experts_for_sglang_gate_up_proj():
    """FC1 weights of both experts are emitted as one fused ``gate_up_proj`` tensor.

    Expert 1 is submitted before expert 0, so the fused result must be ordered
    by expert index rather than by arrival order.
    """
    exporter = load_raw_export_module()
    cfg = convert_args()
    first = torch.arange(24, dtype=torch.float32).view(4, 6)
    second = first + 100

    # Nothing is emitted until every expert has been seen.
    pending = convert_expert(exporter, cfg, "linear_fc1.weight", second, 1)
    assert pending == []
    fused = convert_expert(exporter, cfg, "linear_fc1.weight", first, 0)

    per_expert = [interleave_gate_up(t).transpose(0, 1).contiguous() for t in (first, second)]
    expected = torch.stack(per_expert, dim=0)
    assert len(fused) == 1
    hf_name, hf_weight = fused[0]
    assert hf_name == "model.layers.3.mlp.experts.gate_up_proj"
    assert torch.equal(hf_weight, expected)


@pytest.mark.unit
def test_gpt_oss_raw_fc2_weight_transposes_and_fuses_experts_for_sglang_down_proj():
    """FC2 weights are transposed per expert and emitted as one fused ``down_proj``."""
    exporter = load_raw_export_module()
    cfg = convert_args()
    first = torch.arange(18, dtype=torch.float32).view(6, 3)
    second = first + 100

    # The first expert alone produces no output; the second completes the set.
    pending = convert_expert(exporter, cfg, "linear_fc2.weight", first, 0)
    assert pending == []
    fused = convert_expert(exporter, cfg, "linear_fc2.weight", second, 1)

    per_expert = [t.transpose(0, 1).contiguous() for t in (first, second)]
    expected = torch.stack(per_expert, dim=0)
    assert len(fused) == 1
    hf_name, hf_weight = fused[0]
    assert hf_name == "model.layers.3.mlp.experts.down_proj"
    assert torch.equal(hf_weight, expected)


@pytest.mark.unit
def test_gpt_oss_raw_fc1_bias_fuses_interleaved_gate_up_biases():
    """FC1 biases are interleaved per expert and emitted as ``gate_up_proj_bias``."""
    exporter = load_raw_export_module()
    cfg = convert_args()
    first = torch.arange(4, dtype=torch.float32)
    second = first + 100

    # The first expert alone produces no output; the second completes the set.
    pending = convert_expert(exporter, cfg, "linear_fc1.bias", first, 0)
    assert pending == []
    fused = convert_expert(exporter, cfg, "linear_fc1.bias", second, 1)

    per_expert = [interleave_gate_up(t) for t in (first, second)]
    expected = torch.stack(per_expert, dim=0)
    assert len(fused) == 1
    hf_name, hf_weight = fused[0]
    assert hf_name == "model.layers.3.mlp.experts.gate_up_proj_bias"
    assert torch.equal(hf_weight, expected)


@pytest.mark.unit
def test_gpt_oss_raw_fc2_bias_fuses_down_proj_biases():
    """FC2 biases are stacked unchanged, in expert-index order, as ``down_proj_bias``.

    Expert 1 is submitted before expert 0 to check that ordering follows the
    expert index rather than arrival order.
    """
    exporter = load_raw_export_module()
    cfg = convert_args()
    first = torch.arange(3, dtype=torch.float32)
    second = first + 100

    pending = convert_expert(exporter, cfg, "linear_fc2.bias", second, 1)
    assert pending == []
    fused = convert_expert(exporter, cfg, "linear_fc2.bias", first, 0)

    expected = torch.stack((first, second), dim=0)
    assert len(fused) == 1
    hf_name, hf_weight = fused[0]
    assert hf_name == "model.layers.3.mlp.experts.down_proj_bias"
    assert torch.equal(hf_weight, expected)


if __name__ == "__main__":
    # Running this file directly hands it to pytest and exits with pytest's
    # return code.
    sys.exit(pytest.main([__file__]))
Loading