Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 128 additions & 0 deletions tests/model/test_fsdp_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.tensor import DTensor

from xtuner._testing.testcase import DeterministicDDPTestCase
from xtuner.v1.config import FSDPConfig
from xtuner.v1.data_proto import SequenceContext
from xtuner.v1.loss.ce_loss import CELossConfig
from xtuner.v1.model.base import BaseModel, ModelOutputs, XTunerBaseModelConfig
from xtuner.v1.module import LMHead


class ToyModelConfig(XTunerBaseModelConfig):
    """Minimal model config used to exercise FSDP wrapping in tests."""

    vocab_size: int = 32
    hidden_size: int = 16
    intermediate_size: int = 24

    def build(self) -> "ToyModel":
        """Instantiate the toy model described by this config."""
        return ToyModel(config=self)


class ToyModel(BaseModel):
    """Toy causal-LM-style model: embedding -> linear + ReLU -> LM head."""

    config: ToyModelConfig

    def __init__(self, config: ToyModelConfig):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.lm_head = LMHead(config.intermediate_size, config.vocab_size, bias=False)
        self._init_load_spec()

    def to_hf_key_list(self, key: str) -> list[str]:
        # Identity mapping: local parameter names match the HF checkpoint keys.
        return [key]

    def forward(self, seq_ctx: SequenceContext, loss_ctx=None) -> ModelOutputs:
        assert seq_ctx.input_ids is not None
        features = torch.relu(self.fc1(self.embed_tokens(seq_ctx.input_ids)))
        # The LM head computes the loss itself when given a loss context.
        head_loss_ctx = None if loss_ctx is None else loss_ctx["lm"]
        loss, (logits, extra_info) = self.lm_head(features, head_loss_ctx)
        return ModelOutputs(loss=loss, logits=logits, extra_info=extra_info)


class ReferenceToyModel(nn.Module):
    """Plain ``nn.Module`` twin of ``ToyModel``, used as the non-FSDP reference."""

    def __init__(self, vocab_size: int = 32, hidden_size: int = 16, intermediate_size: int = 24):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.fc1 = nn.Linear(hidden_size, intermediate_size)
        self.lm_head = LMHead(intermediate_size, vocab_size, bias=False)

    def forward(self, seq_ctx: SequenceContext, loss_ctx=None) -> ModelOutputs:
        assert seq_ctx.input_ids is not None
        features = torch.relu(self.fc1(self.embed_tokens(seq_ctx.input_ids)))
        # Mirror ToyModel.forward exactly so outputs are directly comparable.
        head_loss_ctx = None if loss_ctx is None else loss_ctx["lm"]
        loss, (logits, extra_info) = self.lm_head(features, head_loss_ctx)
        return ModelOutputs(loss=loss, logits=logits, extra_info=extra_info)


def _full_tensor(tensor: torch.Tensor) -> torch.Tensor:
return tensor.full_tensor() if isinstance(tensor, DTensor) else tensor


def _build_batch(vocab_size: int, device: str) -> tuple[SequenceContext, dict[str, object]]:
    """Create one random 8-token training batch plus its CE loss context."""
    tokens = torch.randint(0, vocab_size, (1, 9), dtype=torch.int64, device=device)
    # Inputs are the first 8 tokens; labels are the same sequence shifted by one.
    seq_ctx = SequenceContext.from_input_ids(input_ids=(tokens[:, :-1],), device=device)

    loss_cfg = CELossConfig()
    lm_loss_ctx = loss_cfg.build(data={"shifted_labels": tokens[:, 1:]}, sp_mesh=None)
    assert lm_loss_ctx is not None
    batched_ctx = loss_cfg.loss_ctx_cls.build_batches([lm_loss_ctx])[0]
    return seq_ctx, {"lm": batched_ctx}


class TestFSDPModel(DeterministicDDPTestCase):
    """Parity test: an FSDP-sharded ToyModel must match a replicated reference
    model on loss, logits, and parameters after one optimizer step."""

    @property
    def world_size(self) -> int:
        # Run on 4 ranks so the FSDP sharding is non-trivial.
        return 4

    def test_model_forward_backward(self):
        """Compare forward loss/logits and one AdamW step against the reference."""
        self.create_pg("cuda")

        # Same seed on every rank -> identical batch and initial weights everywhere.
        torch.manual_seed(0)
        device = "cuda"
        config = ToyModelConfig(compile_cfg=False)
        seq_ctx, loss_ctx = _build_batch(config.vocab_size, device)

        ref_model = ReferenceToyModel(
            vocab_size=config.vocab_size,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
        ).to(device)
        model = config.build().to(device)
        # Copy the reference weights BEFORE sharding so both models start equal.
        model.load_state_dict(ref_model.state_dict())

        ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-3, weight_decay=1e-2)
        ref_output = ref_model(seq_ctx=seq_ctx, loss_ctx=loss_ctx)
        assert ref_output.loss is not None
        assert ref_output.logits is not None
        ref_output.loss.backward()
        for param in ref_model.parameters():
            assert param.grad is not None
            # NOTE(review): reference grads are SUM-reduced across ranks to mirror
            # what FSDP's gradient reduction produces here — presumably the loss
            # normalization makes SUM (not AVG) the matching op; confirm against
            # FSDPConfig's reduction semantics.
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)

        fsdp_config = FSDPConfig(
            param_dtype=torch.float32,
            reduce_dtype=torch.float32,
            torch_compile=False,
        )
        # Shard after loading weights; full fp32 keeps the comparison exact.
        model.fully_shard(fsdp_config=fsdp_config)
        fsdp_optim = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

        output = model(seq_ctx=seq_ctx, loss_ctx=loss_ctx)
        assert output.loss is not None
        assert output.logits is not None
        output.loss.backward()

        torch.testing.assert_close(output.loss, ref_output.loss)
        torch.testing.assert_close(output.logits, ref_output.logits)

        ref_optim.step()
        fsdp_optim.step()

        # After one step, every gathered FSDP shard must equal the reference weight.
        for name, ref_param in ref_model.state_dict().items():
            torch.testing.assert_close(_full_tensor(model.state_dict()[name]), ref_param)
89 changes: 89 additions & 0 deletions xtuner/v1/model/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib
import json
import math
import pydoc
Expand Down Expand Up @@ -28,6 +29,7 @@
)
from torch.distributed.tensor import DTensor, Placement, Replicate, Shard, distribute_tensor
from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
from torch.utils import _pytree
from typing_extensions import NotRequired, Self, TypedDict, overload

from transformers.configuration_utils import PretrainedConfig
Expand Down Expand Up @@ -243,6 +245,93 @@ def __getitem__(self, key):
def __contains__(self, key):
    # Membership reflects only fields explicitly set on this instance,
    # not fields that merely carry their declared defaults.
    return key in self.model_fields_set

def __init_subclass__(cls, **kwargs: Any) -> None:
    """Register every ``ModelOutputs`` subclass as a pytree node at class
    definition time, so FSDP can traverse the output tensors and insert
    pre-backward hooks.

    ``__init_subclass__`` is implicitly a classmethod (Python data model),
    so the explicit ``@classmethod`` decorator is redundant and omitted.
    """
    super().__init_subclass__(**kwargs)
    cls._register_pytree_node()

@staticmethod
def _model_field_names(model_type: type[PydanticBaseModel]) -> list[str]:
    """Return the declared field names of *model_type*, in declaration order."""
    return [name for name in model_type.model_fields]

@staticmethod
def _flatten_pydantic_model(
    model: PydanticBaseModel,
) -> tuple[list[Any], tuple[type[PydanticBaseModel], list[str]]]:
    """Flatten *model* into its field values (leaves) plus a context tuple
    carrying enough information for the unflatten step to rebuild it."""
    model_type = type(model)
    field_names = ModelOutputs._model_field_names(model_type)
    leaves = [getattr(model, name) for name in field_names]
    return leaves, (model_type, field_names)

@staticmethod
def _unflatten_pydantic_model(
    children: Iterable[Any],
    context: tuple[type[PydanticBaseModel], list[str]],
) -> PydanticBaseModel:
    """Rebuild a model from (possibly transformed) leaf values.

    ``model_construct`` bypasses Pydantic validation; that is safe because
    the leaves were produced by the flatten step from an already-validated
    instance.
    """
    model_type, field_names = context
    kwargs = dict(zip(field_names, children, strict=True))
    return model_type.model_construct(**kwargs)

@staticmethod
def _flatten_pydantic_model_with_keys(
    model: PydanticBaseModel,
) -> tuple[list[tuple[_pytree.KeyEntry, Any]], tuple[type[PydanticBaseModel], list[str]]]:
    """Key-aware variant of ``_flatten_pydantic_model``.

    Pairing each leaf with a ``GetAttrKey`` lets pytree-aware tooling
    (e.g. ``torch.export``) emit readable paths such as ``logits`` instead
    of bare integer indices.
    """
    model_type = type(model)
    field_names = ModelOutputs._model_field_names(model_type)
    keyed_leaves: list[tuple[_pytree.KeyEntry, Any]] = []
    for name in field_names:
        keyed_leaves.append((_pytree.GetAttrKey(name), getattr(model, name)))
    return keyed_leaves, (model_type, field_names)

@staticmethod
def _to_dumpable_context(context: tuple[type[PydanticBaseModel], list[str]]) -> dict[str, Any]:
    """Serialize the pytree context into a JSON-compatible dict so the tree
    structure can be persisted (e.g. for torch.export / torch.compile cache)."""
    model_cls, names = context
    return {"module": model_cls.__module__, "qualname": model_cls.__qualname__, "field_names": names}

@staticmethod
def _from_dumpable_context(context: dict[str, Any]) -> tuple[type[PydanticBaseModel], list[str]]:
    """Inverse of ``_to_dumpable_context``: re-import the model class.

    Walks the dotted ``qualname`` attribute-by-attribute from the imported
    module so that nested classes resolve correctly.
    """
    target: Any = importlib.import_module(context["module"])
    for part in context["qualname"].split("."):
        target = getattr(target, part)
    return target, list(context["field_names"])

@classmethod
def _register_pytree_node(cls) -> None:
    """Register *cls* with torch's pytree machinery.

    Idempotent: the ``SUPPORTED_NODES`` membership check guards against
    double registration (e.g. when the module is reloaded), which would
    otherwise raise inside ``register_pytree_node``.
    """
    if cls in _pytree.SUPPORTED_NODES:
        return

    _pytree.register_pytree_node(
        cls,
        cls._flatten_pydantic_model,
        cls._unflatten_pydantic_model,
        # Stable name used when the tree spec is serialized (torch.export).
        serialized_type_name=f"{cls.__module__}.{cls.__qualname__}",
        to_dumpable_context=cls._to_dumpable_context,
        from_dumpable_context=cls._from_dumpable_context,
        flatten_with_keys_fn=cls._flatten_pydantic_model_with_keys,
    )


# Register the base class itself at import time; subclasses are registered
# automatically by __init_subclass__, which does not fire for the base class.
ModelOutputs._register_pytree_node()


def _is_float8_available():
# Float8 is only supported on SM89 or later (H100+ GPUs)
Expand Down
Loading