diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py
index 5757e64a9..46696aa3b 100644
--- a/src/modalities/config/config.py
+++ b/src/modalities/config/config.py
@@ -186,7 +186,7 @@ class OneCycleLRSchedulerConfig(BaseModel):
     steps_per_epoch: Optional[Annotated[int, Field(strict=True, gt=0)]] = None
     pct_start: Annotated[float, Field(strict=True, gt=0.0, le=1.0)]
     anneal_strategy: str
-    cycle_momentum: bool = True
+    cycle_momentum: bool = False
     base_momentum: Annotated[float, Field(strict=True, gt=0)] | list[
         Annotated[float, Field(strict=True, gt=0.0)]
     ] = 0.85
@@ -227,6 +227,22 @@ class CosineAnnealingLRSchedulerConfig(BaseModel):
     last_epoch: Annotated[int, Field(strict=True, ge=-1)] = -1
 
 
+class LinearWarmupCosineAnnealingLRSchedulerConfig(BaseModel):
+    optimizer: PydanticOptimizerIFType
+    warmup_steps: Annotated[int, Field(strict=True, gt=0)]
+    total_steps: Annotated[int, Field(strict=True, gt=0)]
+    initial_lr: Annotated[float, Field(strict=True, ge=0.0)]
+    final_lr: Annotated[float, Field(strict=True, ge=0.0)]
+    max_lr: Annotated[float, Field(strict=True, ge=0.0)]
+    last_epoch: Annotated[int, Field(strict=True, ge=-1)] = -1
+
+    @model_validator(mode="after")
+    def check_total_steps_greater_than_warmup_steps(self) -> "LinearWarmupCosineAnnealingLRSchedulerConfig":
+        if self.total_steps <= self.warmup_steps:
+            raise ValueError("total_steps must be greater than warmup_steps.")
+        return self
+
+
 class FSDP1CheckpointedOptimizerConfig(BaseModel):
     checkpoint_loading: PydanticFSDP1CheckpointLoadingIFType
     checkpoint_path: Path
diff --git a/src/modalities/optimizers/lr_schedulers.py b/src/modalities/optimizers/lr_schedulers.py
index fca0b2909..c508d2295 100644
--- a/src/modalities/optimizers/lr_schedulers.py
+++ b/src/modalities/optimizers/lr_schedulers.py
@@ -1,21 +1,64 @@
 import warnings
-from typing import Optional
 
+from torch import Tensor
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import LRScheduler
+from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, LRScheduler, SequentialLR
 
 
 class DummyLRScheduler(LRScheduler):
-    def __init__(self, optimizer: Optimizer, last_epoch: Optional[int] = -1):
+    def __init__(self, optimizer: Optimizer, last_epoch: int = -1):
         super().__init__(optimizer, last_epoch)
 
-    def get_lr(self) -> list[float]:
+    def get_lr(self) -> list[float | Tensor]:
         if not self._get_lr_called_within_step:  # type error expected due to internal pytorch implementation
             warnings.warn(
-                "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", UserWarning
+                "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.",
+                UserWarning,
             )
         return [group["lr"] for group in self.optimizer.param_groups]
 
-    def _get_closed_form_lr(self) -> list[float]:
+    def _get_closed_form_lr(self) -> list[float | Tensor]:
         return self.base_lrs
+
+
+class LRSchedulerFactory:
+    @staticmethod
+    def get_linear_warmup_cosine_annealing_lr_scheduler(
+        optimizer: Optimizer,
+        warmup_steps: int,
+        total_steps: int,
+        initial_lr: float,
+        final_lr: float,
+        max_lr: float,
+        last_epoch: int = -1,
+    ) -> SequentialLR:
+        if warmup_steps <= 0:
+            raise ValueError("warmup_steps must be greater than 0.")
+        if total_steps <= warmup_steps:
+            raise ValueError("total_steps must be greater than warmup_steps.")
+
+        if not all(group["lr"] == max_lr for group in optimizer.param_groups):
+            raise ValueError(
+                "All parameter groups must have the same lr, "
+                "and it must be equal to the max_lr passed to the LR scheduler factory."
+            )
+
+        warmup_scheduler = LinearLR(
+            optimizer=optimizer,
+            start_factor=initial_lr / max_lr,
+            end_factor=1.0,
+            total_iters=warmup_steps,
+        )
+        cosine_scheduler = CosineAnnealingLR(
+            optimizer=optimizer,
+            T_max=total_steps - warmup_steps,
+            eta_min=final_lr,
+        )
+
+        return SequentialLR(
+            optimizer=optimizer,
+            schedulers=[warmup_scheduler, cosine_scheduler],
+            milestones=[warmup_steps],
+            last_epoch=last_epoch,
+        )
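Reviewer note: a minimal usage sketch of the new factory. The model, optimizer choice, and hyperparameters below are made up for illustration and do not come from a real modalities config; the one real constraint is that the optimizer's lr must equal max_lr, or the factory's parameter-group check raises.

    import torch

    from modalities.optimizers.lr_schedulers import LRSchedulerFactory

    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)  # lr must equal max_lr below

    scheduler = LRSchedulerFactory.get_linear_warmup_cosine_annealing_lr_scheduler(
        optimizer=optimizer,
        warmup_steps=100,   # linear ramp from initial_lr to max_lr over the first 100 steps
        total_steps=1000,   # cosine decay to final_lr over the remaining 900 steps
        initial_lr=3e-5,
        final_lr=3e-6,
        max_lr=3e-4,
    )

    for _ in range(1000):
        optimizer.step()    # a real loop would compute a loss and backprop first
        scheduler.step()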
+ "and it must be equal to the initial_lr passed to the LR scheduler factory." + ) + + warmup_scheduler = LinearLR( + optimizer=optimizer, + start_factor=initial_lr / max_lr, + end_factor=1, + total_iters=warmup_steps, + ) + cosine_scheduler = CosineAnnealingLR( + optimizer=optimizer, + T_max=total_steps - warmup_steps, + eta_min=final_lr, + ) + + return SequentialLR( + optimizer=optimizer, + schedulers=[warmup_scheduler, cosine_scheduler], + milestones=[warmup_steps], + last_epoch=last_epoch, + ) diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 67f100f0f..26df9b432 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -48,6 +48,7 @@ GPT2MFUCalculatorConfig, GPT2ModelTPConfig, LinearLRSchedulerConfig, + LinearWarmupCosineAnnealingLRSchedulerConfig, LLMDataLoaderConfig, MemMapDatasetConfig, OneCycleLRSchedulerConfig, @@ -108,7 +109,7 @@ ComposedInitializationRoutines, ComposedModelInitializationConfig, ) -from modalities.optimizers.lr_schedulers import DummyLRScheduler +from modalities.optimizers.lr_schedulers import DummyLRScheduler, LRSchedulerFactory from modalities.optimizers.optimizer_factory import OptimizerFactory from modalities.optimizers.optimizer_list import OptimizersList from modalities.optimizers.scheduler_list import SchedulerList @@ -285,6 +286,12 @@ class ComponentEntity: maybe_optimizer_list(torch.optim.lr_scheduler.CosineAnnealingLR), CosineAnnealingLRSchedulerConfig, ), + ComponentEntity( + "scheduler", + "linear_warmup_cosine_annealing_lr", + maybe_optimizer_list(LRSchedulerFactory.get_linear_warmup_cosine_annealing_lr_scheduler), + LinearWarmupCosineAnnealingLRSchedulerConfig, + ), # tokenizers ComponentEntity("tokenizer", "pretrained_hf_tokenizer", PreTrainedHFTokenizer, PreTrainedHFTokenizerConfig), ComponentEntity("tokenizer", "pretrained_sp_tokenizer", PreTrainedSPTokenizer, PreTrainedSPTokenizerConfig), diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py index b1f005940..be50bb6e9 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -1,5 +1,6 @@ from datetime import datetime from enum import Enum +import gc from typing import Callable, Optional import torch @@ -26,6 +27,25 @@ from modalities.utils.typing_utils import FSDPX +class GarbageCollection: + # Some portions of this implementation are inspired, adapted, or refactored + # from Meta's open-source project TorchTitan, + # licensed under the BSD 3-Clause License. 
diff --git a/tests/test_lr_scheduler.py b/tests/test_lr_scheduler.py
index 58793e7ae..c6c12019f 100644
--- a/tests/test_lr_scheduler.py
+++ b/tests/test_lr_scheduler.py
@@ -1,6 +1,7 @@
 from unittest.mock import MagicMock, call
 
 import numpy as np
+import torch
 
 from modalities.checkpointing.checkpoint_saving import CheckpointSaving
 from modalities.checkpointing.stateful.app_state import AppState
@@ -8,7 +9,7 @@
 from modalities.evaluator import Evaluator
 from modalities.gym import Gym
 from modalities.loss_functions import Loss
-from modalities.optimizers.lr_schedulers import DummyLRScheduler
+from modalities.optimizers.lr_schedulers import DummyLRScheduler, LRSchedulerFactory
 from modalities.trainer import Trainer
 from tests.utility import configure_dataloader_mock
 
@@ -76,3 +77,29 @@ def test_dummy_lr_scheduler(optimizer_with_param_groups_mock: MagicMock):
     assert np.allclose(scheduler.get_lr(), [0.08, 0.18, 0.28], atol=1e-6)
     assert scheduler._get_closed_form_lr() == [0.1, 0.2, 0.3]
     assert np.allclose(scheduler.get_last_lr(), [0.08, 0.18, 0.28], atol=1e-6)
+
+
+def test_linear_warmup_cosine_annealing_lr_scheduler():
+    parameter = torch.nn.Parameter(torch.tensor([1.0]))
+    optimizer = torch.optim.SGD([parameter], lr=1.0)
+    scheduler = LRSchedulerFactory.get_linear_warmup_cosine_annealing_lr_scheduler(
+        optimizer=optimizer,
+        warmup_steps=2,
+        total_steps=6,
+        initial_lr=0.1,
+        final_lr=0.2,
+        max_lr=1.0,
+    )
+
+    # Record the LR before any step plus after each of the 6 steps.
+    learning_rates = [scheduler.get_last_lr()[0]]
+    for _ in range(6):
+        optimizer.step()
+        scheduler.step()
+        learning_rates.append(scheduler.get_last_lr()[0])
+
+    # Linear ramp up to max_lr at the end of warmup, then cosine decay to final_lr.
+    assert learning_rates[0] < learning_rates[1] < learning_rates[2]
+    assert np.isclose(learning_rates[2], 1.0, atol=1e-6)
+    assert learning_rates[2] > learning_rates[3] > learning_rates[4] > learning_rates[5] > learning_rates[6]
+    assert np.isclose(learning_rates[6], 0.2, atol=1e-6)
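Reviewer note: the test's assertions can be cross-checked against the closed-form schedule. During warmup, LinearLR ramps linearly from initial_lr to max_lr; from the milestone on, CosineAnnealingLR yields eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T_max)) / 2 with T_max = total_steps - warmup_steps and eta_min = final_lr. The sketch below recomputes the expected trajectory for the test's parameters, assuming SequentialLR switches to the cosine schedule's epoch 0 exactly at the milestone step.

    import math

    warmup_steps, total_steps = 2, 6
    initial_lr, final_lr, max_lr = 0.1, 0.2, 1.0

    expected = []
    for step in range(total_steps + 1):
        if step < warmup_steps:
            # linear interpolation from initial_lr to max_lr
            expected.append(initial_lr + (max_lr - initial_lr) * step / warmup_steps)
        else:
            # cosine annealing from max_lr down to final_lr
            t, t_max = step - warmup_steps, total_steps - warmup_steps
            expected.append(final_lr + (max_lr - final_lr) * (1 + math.cos(math.pi * t / t_max)) / 2)

    print(expected)  # [0.1, 0.55, 1.0, 0.8828..., 0.6, 0.3171..., 0.2]

The printed values rise strictly to 1.0 at index 2 and fall strictly to 0.2 at index 6, which is exactly the shape the test asserts.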