From 97d0cfc86aaa1e90c70f9d71afbf0c157a542695 Mon Sep 17 00:00:00 2001
From: Max Luebbering <2804731+le1nux@users.noreply.github.com>
Date: Fri, 20 Mar 2026 15:16:38 +0000
Subject: [PATCH 1/5] chore: temporarily reverted combination of fsdp units

---
 src/modalities/config/config.py        |  2 +-
 src/modalities/models/model_factory.py | 39 ++++++--------------
 2 files changed, 10 insertions(+), 31 deletions(-)

diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py
index 5757e64a9..7e7ebbf69 100644
--- a/src/modalities/config/config.py
+++ b/src/modalities/config/config.py
@@ -275,7 +275,7 @@ class FSDP2WrappedModelConfig(BaseModel):
     mixed_precision_settings: FSDP2MixedPrecisionSettings
     reshard_after_forward: bool = True
     device_mesh: PydanticDeviceMeshIFType
-    layers_per_fsdp_unit: int = 1
+    # layers_per_fsdp_unit: int = 1
 
     @model_validator(mode="after")
     def validate_mixed_precision_settings(self):
diff --git a/src/modalities/models/model_factory.py b/src/modalities/models/model_factory.py
index 142aef920..c0b52b859 100644
--- a/src/modalities/models/model_factory.py
+++ b/src/modalities/models/model_factory.py
@@ -1,7 +1,3 @@
-# Some portions of this implementation are inspired, adapted, or refactored
-# from Meta's open-source project TorchTitan,
-# licensed under the BSD 3-Clause License.
-
 import itertools
 import json
 import time
@@ -172,7 +168,6 @@ def get_fsdp2_wrapped_model(
         device_mesh: DeviceMesh,
         mixed_precision_settings: FSDP2MixedPrecisionSettings,
         reshard_after_forward: bool,
-        layers_per_fsdp_unit: int = 1,
     ) -> FSDP2:
         """Get the FSDP2-wrapped model.
 
@@ -186,7 +181,6 @@ def get_fsdp2_wrapped_model(
             device_mesh (DeviceMesh): The device mesh.
             mixed_precision_settings (FSDP2MixedPrecisionSettings): Mixed precision settings.
             reshard_after_forward (bool): Whether to reshard after forward.
-            layers_per_fsdp_unit (int): Number of layers per FSDP unit. Default is 1.
 
         Returns:
             FSDP2: The FSDP2-wrapped model.
@@ -211,32 +205,17 @@ def get_fsdp2_wrapped_model(
         fsdp_config = {"mesh": device_mesh[fsdp2_degrees], "mp_policy": mp_policy}
 
         modules = list(model.modules())
-        # we first shard all the blocks
-        grouped_modules: list[nn.Module] = []
-        module_id = 0
         for module_id, module in enumerate(modules):
             if isinstance(module, block_types):
-                grouped_modules.append(module)
-                if len(grouped_modules) == layers_per_fsdp_unit:
-                    # As an optimization, we do not reshard after forward for the last
-                    # transformer block since FSDP would prefetch it immediately.
-                    reshard_block_after_forward = reshard_after_forward and int(module_id) < len(modules) - 1
-                    fully_shard(
-                        grouped_modules,
-                        **fsdp_config,
-                        reshard_after_forward=reshard_block_after_forward,
-                    )
-                    grouped_modules = list()
-
-        if len(grouped_modules) > 0:
-            reshard_block_after_forward = False
-            fully_shard(
-                grouped_modules,
-                **fsdp_config,
-                reshard_after_forward=reshard_block_after_forward,
-            )
-
+                # As an optimization, we do not reshard after forward for the last
+                # transformer block since FSDP would prefetch it immediately.
+                reshard_block_after_forward = reshard_after_forward and int(module_id) < len(modules) - 1
+                fully_shard(
+                    module,
+                    **fsdp_config,
+                    reshard_after_forward=reshard_block_after_forward,
+                )
         # finally, we shard the entire model
         fully_shard(model, **fsdp_config, reshard_after_forward=reshard_after_forward)
         logger.info(
@@ -763,4 +742,4 @@ def get_gpt2_tensor_parallelized_model(model: GPT2LLM, device_mesh: DeviceMesh)
             parallelize_plan=transformer_block_tp_plan,
         )
 
-    return model
+    return model
\ No newline at end of file
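Note: the revert above goes back to wrapping each transformer block as its own FSDP2 unit. The loop over model.modules() can be dry-run without any process group or GPU; in the sketch below, block_types and the toy model are illustrative stand-ins (the real code resolves the block classes from the model config), and reshard_after_forward is assumed True.

# Dry run of the per-block wrapping loop from the patch above (illustration only;
# no fully_shard call, so no distributed setup is required).
import torch.nn as nn

block_types = (nn.TransformerEncoderLayer,)  # stand-in for the resolved block classes
model = nn.Sequential(
    *[nn.TransformerEncoderLayer(d_model=32, nhead=4) for _ in range(3)],
    nn.Linear(32, 8),
)

modules = list(model.modules())
for module_id, module in enumerate(modules):
    if isinstance(module, block_types):
        # mirrors the patch: reshard every block unless it is the last module
        reshard_block_after_forward = module_id < len(modules) - 1
        print(
            f"would fully_shard module {module_id} ({type(module).__name__}), "
            f"reshard_after_forward={reshard_block_after_forward}"
        )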
From 40b2de2a9646101e87c55543d4dd49102ec45d18 Mon Sep 17 00:00:00 2001
From: Max Luebbering <2804731+le1nux@users.noreply.github.com>
Date: Thu, 26 Mar 2026 18:15:45 +0000
Subject: [PATCH 2/5] feat: added garbage collection

---
 src/modalities/trainer.py | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py
index b1f005940..be50bb6e9 100644
--- a/src/modalities/trainer.py
+++ b/src/modalities/trainer.py
@@ -1,5 +1,6 @@
 from datetime import datetime
 from enum import Enum
+import gc
 from typing import Callable, Optional
 
 import torch
@@ -26,6 +27,25 @@
 from modalities.utils.typing_utils import FSDPX
 
 
+class GarbageCollection:
+    # Some portions of this implementation are inspired, adapted, or refactored
+    # from Meta's open-source project TorchTitan,
+    # licensed under the BSD 3-Clause License.
+    def __init__(self, gc_freq: int = 10):
+        assert gc_freq > 0, "gc_freq must be a positive integer"
+        self.gc_freq = gc_freq
+        gc.disable()
+        self.collect()  # GC invoked here
+
+    def run(self, step_count: int):
+        if step_count > 1 and step_count % self.gc_freq == 0:
+            self.collect()  # GC invoked here
+
+    @staticmethod
+    def collect(generation: int = 1):
+        gc.collect(generation)  # GC invoked here
+
+
 class ThroughputAggregationKeys(Enum):
     NUM_SAMPLES = "NUM_SAMPLES"
     FORWARD_BACKWARD_TIME = "FORWARD_BACKWARD_TIME"
@@ -70,6 +90,7 @@ def __init__(
         Returns:
             None
         """
+        self.gc = GarbageCollection(gc_freq=10)
        self.global_rank = global_rank
         if device_mesh is not None:
             self.dp_degree = get_parallel_degree(
@@ -283,6 +304,7 @@ def train(
             )
             # Check if model performance should be logged
             if training_progress.num_seen_steps_total % training_log_interval_in_steps == 0 and step_performed:
+                dist.barrier()
                 forward_backward_time_recorder.stop()
                 forward_backward_time = forward_backward_time_recorder.delta_t
                 forward_backward_time_recorder.reset()
@@ -363,8 +385,10 @@ def train(
                 cumulated_losses.zero_()
 
             if step_performed:
+                self.gc.run(step_count=training_progress.num_seen_steps_total)
                 evaluation_callback(num_train_steps_done=training_progress.num_seen_steps_total)
                 checkpointing_callback(training_progress=training_progress)
+
             profiler_cm.step()
 
     @staticmethod
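Note: the GarbageCollection helper follows the TorchTitan pattern it credits: automatic GC is switched off at startup and a generation-1 collection is triggered manually on a fixed step interval, so collection pauses land at the same training step on every rank instead of firing at arbitrary points per rank. A minimal usage sketch, assuming the class is imported from modalities.trainer where the patch adds it (the Trainer itself hard-codes gc_freq=10):

from modalities.trainer import GarbageCollection

gc_controller = GarbageCollection(gc_freq=10)  # disables automatic GC and collects once up front
for step in range(1, 31):
    # ... forward / backward / optimizer.step() would run here ...
    gc_controller.run(step_count=step)  # triggers gc.collect(1) at steps 10, 20, 30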
From 534bb46132a10ab6512ee317a4c7164cf8874fe0 Mon Sep 17 00:00:00 2001
From: Max Luebbering <2804731+le1nux@users.noreply.github.com>
Date: Thu, 26 Mar 2026 18:16:41 +0000
Subject: [PATCH 3/5] chore: setting cycle_momentum to False in
 OneCycleLRSchedulerConfig

---
 src/modalities/config/config.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py
index 7e7ebbf69..4fead1563 100644
--- a/src/modalities/config/config.py
+++ b/src/modalities/config/config.py
@@ -186,7 +186,7 @@ class OneCycleLRSchedulerConfig(BaseModel):
     steps_per_epoch: Optional[Annotated[int, Field(strict=True, gt=0)]] = None
     pct_start: Annotated[float, Field(strict=True, gt=0.0, le=1.0)]
     anneal_strategy: str
-    cycle_momentum: bool = True
+    cycle_momentum: bool = False
     base_momentum: Annotated[float, Field(strict=True, gt=0)] | list[
         Annotated[float, Field(strict=True, gt=0.0)]
     ] = 0.85
@@ -275,7 +275,7 @@ class FSDP2WrappedModelConfig(BaseModel):
     mixed_precision_settings: FSDP2MixedPrecisionSettings
     reshard_after_forward: bool = True
     device_mesh: PydanticDeviceMeshIFType
-    # layers_per_fsdp_unit: int = 1
+    layers_per_fsdp_unit: int = 1
 
     @model_validator(mode="after")
     def validate_mixed_precision_settings(self):
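Note: the commit message does not state the motivation for flipping cycle_momentum to False, but a common reason is that PyTorch's OneCycleLR with cycle_momentum=True reads and rewrites the optimizer's momentum (or beta1) on every step, and it refuses optimizers that expose neither. A small illustration of both behaviors with plain PyTorch; Adagrad is chosen only because it has no momentum and no betas:

import torch
from torch.optim.lr_scheduler import OneCycleLR

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adagrad([param], lr=0.1)  # has neither momentum nor betas

# works: only the LR is cycled, the optimizer's hyperparameters are left untouched
scheduler = OneCycleLR(optimizer, max_lr=0.1, total_steps=100, cycle_momentum=False)

# raises: cycling momentum requires a momentum/betas hyperparameter on the optimizer
try:
    OneCycleLR(optimizer, max_lr=0.1, total_steps=100, cycle_momentum=True)
except ValueError as err:
    print(err)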
From 20c678e3d12ffe917c3c74e7ad7b2bf1747e889a Mon Sep 17 00:00:00 2001
From: Max Luebbering <2804731+le1nux@users.noreply.github.com>
Date: Thu, 26 Mar 2026 18:17:05 +0000
Subject: [PATCH 4/5] refactor: reintroducing layers_per_fsdp_unit in
 model_factory for better FSDP sharding control

---
 src/modalities/models/model_factory.py | 39 ++++++++++++++++++++------
 1 file changed, 30 insertions(+), 9 deletions(-)

diff --git a/src/modalities/models/model_factory.py b/src/modalities/models/model_factory.py
index c0b52b859..142aef920 100644
--- a/src/modalities/models/model_factory.py
+++ b/src/modalities/models/model_factory.py
@@ -1,3 +1,7 @@
+# Some portions of this implementation are inspired, adapted, or refactored
+# from Meta's open-source project TorchTitan,
+# licensed under the BSD 3-Clause License.
+
 import itertools
 import json
 import time
@@ -168,6 +172,7 @@ def get_fsdp2_wrapped_model(
         device_mesh: DeviceMesh,
         mixed_precision_settings: FSDP2MixedPrecisionSettings,
         reshard_after_forward: bool,
+        layers_per_fsdp_unit: int = 1,
     ) -> FSDP2:
         """Get the FSDP2-wrapped model.
 
@@ -181,6 +186,7 @@ def get_fsdp2_wrapped_model(
             device_mesh (DeviceMesh): The device mesh.
             mixed_precision_settings (FSDP2MixedPrecisionSettings): Mixed precision settings.
             reshard_after_forward (bool): Whether to reshard after forward.
+            layers_per_fsdp_unit (int): Number of layers per FSDP unit. Default is 1.
 
         Returns:
             FSDP2: The FSDP2-wrapped model.
@@ -205,17 +211,32 @@ def get_fsdp2_wrapped_model(
         fsdp_config = {"mesh": device_mesh[fsdp2_degrees], "mp_policy": mp_policy}
 
         modules = list(model.modules())
+        # we first shard all the blocks
+        grouped_modules: list[nn.Module] = []
+        module_id = 0
         for module_id, module in enumerate(modules):
             if isinstance(module, block_types):
-                # As an optimization, we do not reshard after forward for the last
-                # transformer block since FSDP would prefetch it immediately.
-                reshard_block_after_forward = reshard_after_forward and int(module_id) < len(modules) - 1
-                fully_shard(
-                    module,
-                    **fsdp_config,
-                    reshard_after_forward=reshard_block_after_forward,
-                )
+                grouped_modules.append(module)
+                if len(grouped_modules) == layers_per_fsdp_unit:
+                    # As an optimization, we do not reshard after forward for the last
+                    # transformer block since FSDP would prefetch it immediately.
+                    reshard_block_after_forward = reshard_after_forward and int(module_id) < len(modules) - 1
+                    fully_shard(
+                        grouped_modules,
+                        **fsdp_config,
+                        reshard_after_forward=reshard_block_after_forward,
+                    )
+                    grouped_modules = list()
+
+        if len(grouped_modules) > 0:
+            reshard_block_after_forward = False
+            fully_shard(
+                grouped_modules,
+                **fsdp_config,
+                reshard_after_forward=reshard_block_after_forward,
+            )
+
         # finally, we shard the entire model
         fully_shard(model, **fsdp_config, reshard_after_forward=reshard_after_forward)
         logger.info(
@@ -742,4 +763,4 @@ def get_gpt2_tensor_parallelized_model(model: GPT2LLM, device_mesh: DeviceMesh)
             parallelize_plan=transformer_block_tp_plan,
        )
 
-    return model
\ No newline at end of file
+    return model
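Note: with layers_per_fsdp_unit reintroduced, consecutive transformer blocks are batched into a single fully_shard call, trading fewer (but larger) all-gathers against holding more unsharded parameters at once. The chunking itself can be dry-run without any distributed setup; the integer ids below stand in for the block modules:

# Dry run of the grouping logic reintroduced above; ids stand in for modules.
layers_per_fsdp_unit = 2
block_ids = list(range(5))

units: list[list[int]] = []
grouped: list[int] = []
for block_id in block_ids:
    grouped.append(block_id)
    if len(grouped) == layers_per_fsdp_unit:
        units.append(grouped)  # one fully_shard call per completed group
        grouped = []
if grouped:
    units.append(grouped)  # leftover blocks form a final, smaller unit

print(units)  # [[0, 1], [2, 3], [4]] -> 3 fully_shard calls instead of 5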
[group["lr"] for group in self.optimizer.param_groups] - def _get_closed_form_lr(self) -> list[float]: + def _get_closed_form_lr(self) -> list[float | Tensor]: return self.base_lrs + + +class LRSchedulerFactory: + @staticmethod + def get_linear_warmup_cosine_annealing_lr_scheduler( + optimizer: Optimizer, + warmup_steps: int, + total_steps: int, + initial_lr: float, + final_lr: float, + max_lr: float, + last_epoch: int = -1, + ) -> SequentialLR: + if warmup_steps <= 0: + raise ValueError("warmup_steps must be greater than 0.") + if total_steps <= warmup_steps: + raise ValueError("total_steps must be greater than warmup_steps.") + + if not all(base_lr == max_lr for base_lr in [group["lr"] for group in optimizer.param_groups]): + raise ValueError( + "All parameter groups must have the same initial_lr." + "and it must be equal to the initial_lr passed to the LR scheduler factory." + ) + + warmup_scheduler = LinearLR( + optimizer=optimizer, + start_factor=initial_lr / max_lr, + end_factor=1, + total_iters=warmup_steps, + ) + cosine_scheduler = CosineAnnealingLR( + optimizer=optimizer, + T_max=total_steps - warmup_steps, + eta_min=final_lr, + ) + + return SequentialLR( + optimizer=optimizer, + schedulers=[warmup_scheduler, cosine_scheduler], + milestones=[warmup_steps], + last_epoch=last_epoch, + ) diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 67f100f0f..26df9b432 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -48,6 +48,7 @@ GPT2MFUCalculatorConfig, GPT2ModelTPConfig, LinearLRSchedulerConfig, + LinearWarmupCosineAnnealingLRSchedulerConfig, LLMDataLoaderConfig, MemMapDatasetConfig, OneCycleLRSchedulerConfig, @@ -108,7 +109,7 @@ ComposedInitializationRoutines, ComposedModelInitializationConfig, ) -from modalities.optimizers.lr_schedulers import DummyLRScheduler +from modalities.optimizers.lr_schedulers import DummyLRScheduler, LRSchedulerFactory from modalities.optimizers.optimizer_factory import OptimizerFactory from modalities.optimizers.optimizer_list import OptimizersList from modalities.optimizers.scheduler_list import SchedulerList @@ -285,6 +286,12 @@ class ComponentEntity: maybe_optimizer_list(torch.optim.lr_scheduler.CosineAnnealingLR), CosineAnnealingLRSchedulerConfig, ), + ComponentEntity( + "scheduler", + "linear_warmup_cosine_annealing_lr", + maybe_optimizer_list(LRSchedulerFactory.get_linear_warmup_cosine_annealing_lr_scheduler), + LinearWarmupCosineAnnealingLRSchedulerConfig, + ), # tokenizers ComponentEntity("tokenizer", "pretrained_hf_tokenizer", PreTrainedHFTokenizer, PreTrainedHFTokenizerConfig), ComponentEntity("tokenizer", "pretrained_sp_tokenizer", PreTrainedSPTokenizer, PreTrainedSPTokenizerConfig), diff --git a/tests/test_lr_scheduler.py b/tests/test_lr_scheduler.py index 58793e7ae..c6c12019f 100644 --- a/tests/test_lr_scheduler.py +++ b/tests/test_lr_scheduler.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, call import numpy as np +import torch from modalities.checkpointing.checkpoint_saving import CheckpointSaving from modalities.checkpointing.stateful.app_state import AppState @@ -8,7 +9,7 @@ from modalities.evaluator import Evaluator from modalities.gym import Gym from modalities.loss_functions import Loss -from modalities.optimizers.lr_schedulers import DummyLRScheduler +from modalities.optimizers.lr_schedulers import DummyLRScheduler, LRSchedulerFactory from modalities.trainer import Trainer from tests.utility import configure_dataloader_mock @@ -76,3 +77,27 
diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py
index 67f100f0f..26df9b432 100644
--- a/src/modalities/registry/components.py
+++ b/src/modalities/registry/components.py
@@ -48,6 +48,7 @@
     GPT2MFUCalculatorConfig,
     GPT2ModelTPConfig,
     LinearLRSchedulerConfig,
+    LinearWarmupCosineAnnealingLRSchedulerConfig,
     LLMDataLoaderConfig,
     MemMapDatasetConfig,
     OneCycleLRSchedulerConfig,
@@ -108,7 +109,7 @@
     ComposedInitializationRoutines,
     ComposedModelInitializationConfig,
 )
-from modalities.optimizers.lr_schedulers import DummyLRScheduler
+from modalities.optimizers.lr_schedulers import DummyLRScheduler, LRSchedulerFactory
 from modalities.optimizers.optimizer_factory import OptimizerFactory
 from modalities.optimizers.optimizer_list import OptimizersList
 from modalities.optimizers.scheduler_list import SchedulerList
@@ -285,6 +286,12 @@ class ComponentEntity:
         maybe_optimizer_list(torch.optim.lr_scheduler.CosineAnnealingLR),
         CosineAnnealingLRSchedulerConfig,
     ),
+    ComponentEntity(
+        "scheduler",
+        "linear_warmup_cosine_annealing_lr",
+        maybe_optimizer_list(LRSchedulerFactory.get_linear_warmup_cosine_annealing_lr_scheduler),
+        LinearWarmupCosineAnnealingLRSchedulerConfig,
+    ),
     # tokenizers
     ComponentEntity("tokenizer", "pretrained_hf_tokenizer", PreTrainedHFTokenizer, PreTrainedHFTokenizerConfig),
     ComponentEntity("tokenizer", "pretrained_sp_tokenizer", PreTrainedSPTokenizer, PreTrainedSPTokenizerConfig),
diff --git a/tests/test_lr_scheduler.py b/tests/test_lr_scheduler.py
index 58793e7ae..c6c12019f 100644
--- a/tests/test_lr_scheduler.py
+++ b/tests/test_lr_scheduler.py
@@ -1,6 +1,7 @@
 from unittest.mock import MagicMock, call
 
 import numpy as np
+import torch
 
 from modalities.checkpointing.checkpoint_saving import CheckpointSaving
 from modalities.checkpointing.stateful.app_state import AppState
@@ -8,7 +9,7 @@
 from modalities.evaluator import Evaluator
 from modalities.gym import Gym
 from modalities.loss_functions import Loss
-from modalities.optimizers.lr_schedulers import DummyLRScheduler
+from modalities.optimizers.lr_schedulers import DummyLRScheduler, LRSchedulerFactory
 from modalities.trainer import Trainer
 from tests.utility import configure_dataloader_mock
 
@@ -76,3 +77,27 @@ def test_dummy_lr_scheduler(optimizer_with_param_groups_mock: MagicMock):
     assert np.allclose(scheduler.get_lr(), [0.08, 0.18, 0.28], atol=1e-6)
     assert scheduler._get_closed_form_lr() == [0.1, 0.2, 0.3]
     assert np.allclose(scheduler.get_last_lr(), [0.08, 0.18, 0.28], atol=1e-6)
+
+
+def test_linear_warmup_cosine_annealing_lr_scheduler():
+    parameter = torch.nn.Parameter(torch.tensor([1.0]))
+    optimizer = torch.optim.SGD([parameter], lr=1.0)
+    scheduler = LRSchedulerFactory.get_linear_warmup_cosine_annealing_lr_scheduler(
+        optimizer=optimizer,
+        warmup_steps=2,
+        total_steps=6,
+        initial_lr=0.1,
+        final_lr=0.2,
+        max_lr=1.0,
+    )
+
+    learning_rates = [scheduler.get_last_lr()[0]]
+    for _ in range(6):
+        optimizer.step()
+        scheduler.step()
+        learning_rates.append(scheduler.get_last_lr()[0])
+
+    assert learning_rates[0] < learning_rates[1] < learning_rates[2]
+    assert np.isclose(learning_rates[2], 1.0, atol=1e-6)
+    assert learning_rates[2] > learning_rates[3] > learning_rates[4] > learning_rates[5] > learning_rates[6]
+    assert np.isclose(learning_rates[6], 0.2, atol=1e-6)
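Note: the composed schedule added in this patch can be reproduced with plain PyTorch (no modalities install) to see its two phases: a linear ramp from initial_lr to max_lr over warmup_steps, then cosine decay to final_lr. This is a minimal sketch mirroring the factory above; the concrete values are arbitrary:

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

initial_lr, max_lr, final_lr = 1e-5, 1e-3, 1e-5
warmup_steps, total_steps = 10, 100

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=max_lr)  # the base lr doubles as max_lr

warmup = LinearLR(optimizer, start_factor=initial_lr / max_lr, end_factor=1.0, total_iters=warmup_steps)
cosine = CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps, eta_min=final_lr)
scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[warmup_steps])

for step in range(total_steps):
    optimizer.step()
    scheduler.step()
    if step in (warmup_steps - 1, total_steps - 1):
        print(step, scheduler.get_last_lr()[0])  # reaches max_lr at the end of warmup, final_lr at the end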