18 changes: 17 additions & 1 deletion src/modalities/config/config.py
@@ -186,7 +186,7 @@ class OneCycleLRSchedulerConfig(BaseModel):
     steps_per_epoch: Optional[Annotated[int, Field(strict=True, gt=0)]] = None
     pct_start: Annotated[float, Field(strict=True, gt=0.0, le=1.0)]
     anneal_strategy: str
-    cycle_momentum: bool = True
+    cycle_momentum: bool = False
     base_momentum: Annotated[float, Field(strict=True, gt=0)] | list[
         Annotated[float, Field(strict=True, gt=0.0)]
     ] = 0.85
@@ -227,6 +227,22 @@ class CosineAnnealingLRSchedulerConfig(BaseModel):
     last_epoch: Annotated[int, Field(strict=True, ge=-1)] = -1
 
 
+class LinearWarmupCosineAnnealingLRSchedulerConfig(BaseModel):
+    optimizer: PydanticOptimizerIFType
+    warmup_steps: Annotated[int, Field(strict=True, gt=0)]
+    total_steps: Annotated[int, Field(strict=True, gt=0)]
+    initial_lr: Annotated[float, Field(strict=True, ge=0.0)]
+    final_lr: Annotated[float, Field(strict=True, ge=0.0)]
+    max_lr: Annotated[float, Field(strict=True, ge=0.0)]
+    last_epoch: Annotated[int, Field(strict=True, ge=-1)] = -1
+
+    @model_validator(mode="after")
+    def check_total_steps_greater_than_warmup_steps(self) -> "LinearWarmupCosineAnnealingLRSchedulerConfig":
+        if self.total_steps <= self.warmup_steps:
+            raise ValueError("total_steps must be greater than warmup_steps.")
+        return self
+
+
 class FSDP1CheckpointedOptimizerConfig(BaseModel):
     checkpoint_loading: PydanticFSDP1CheckpointLoadingIFType
     checkpoint_path: Path
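Not part of the diff: a minimal, runnable sketch of how the `model_validator(mode="after")` check above behaves. The model name is hypothetical, and the `optimizer` field is omitted so the example does not depend on the modalities optimizer interface.

```python
from typing import Annotated

from pydantic import BaseModel, Field, ValidationError, model_validator


class WarmupCosineStepsSketch(BaseModel):
    # Hypothetical stand-in for LinearWarmupCosineAnnealingLRSchedulerConfig.
    warmup_steps: Annotated[int, Field(strict=True, gt=0)]
    total_steps: Annotated[int, Field(strict=True, gt=0)]

    @model_validator(mode="after")
    def check_total_steps_greater_than_warmup_steps(self) -> "WarmupCosineStepsSketch":
        if self.total_steps <= self.warmup_steps:
            raise ValueError("total_steps must be greater than warmup_steps.")
        return self


WarmupCosineStepsSketch(warmup_steps=100, total_steps=1_000)  # passes validation
try:
    WarmupCosineStepsSketch(warmup_steps=1_000, total_steps=100)
except ValidationError as err:
    print(err)  # Pydantic surfaces the ValueError raised inside the validator
```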
55 changes: 49 additions & 6 deletions src/modalities/optimizers/lr_schedulers.py
@@ -1,21 +1,64 @@
 import warnings
-from typing import Optional
 
+from torch import Tensor
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import LRScheduler
+from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, LRScheduler, SequentialLR
 
 
 class DummyLRScheduler(LRScheduler):
-    def __init__(self, optimizer: Optimizer, last_epoch: Optional[int] = -1):
+    def __init__(self, optimizer: Optimizer, last_epoch: int = -1):
         super().__init__(optimizer, last_epoch)
 
-    def get_lr(self) -> list[float]:
+    def get_lr(self) -> list[float | Tensor]:
         if not self._get_lr_called_within_step:  # type error expected due to internal pytorch implementation
             warnings.warn(
-                "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", UserWarning
+                "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.",
+                UserWarning,
             )
 
         return [group["lr"] for group in self.optimizer.param_groups]
 
-    def _get_closed_form_lr(self) -> list[float]:
+    def _get_closed_form_lr(self) -> list[float | Tensor]:
         return self.base_lrs
+
+
+class LRSchedulerFactory:
+    @staticmethod
+    def get_linear_warmup_cosine_annealing_lr_scheduler(
+        optimizer: Optimizer,
+        warmup_steps: int,
+        total_steps: int,
+        initial_lr: float,
+        final_lr: float,
+        max_lr: float,
+        last_epoch: int = -1,
+    ) -> SequentialLR:
+        if warmup_steps <= 0:
+            raise ValueError("warmup_steps must be greater than 0.")
+        if total_steps <= warmup_steps:
+            raise ValueError("total_steps must be greater than warmup_steps.")
+
+        if not all(base_lr == max_lr for base_lr in [group["lr"] for group in optimizer.param_groups]):
+            raise ValueError(
+                "All parameter groups must have the same lr "
+                "and it must be equal to the max_lr passed to the LR scheduler factory."
+            )
+
+        warmup_scheduler = LinearLR(
+            optimizer=optimizer,
+            start_factor=initial_lr / max_lr,
+            end_factor=1,
+            total_iters=warmup_steps,
+        )
+        cosine_scheduler = CosineAnnealingLR(
+            optimizer=optimizer,
+            T_max=total_steps - warmup_steps,
+            eta_min=final_lr,
+        )
+
+        return SequentialLR(
+            optimizer=optimizer,
+            schedulers=[warmup_scheduler, cosine_scheduler],
+            milestones=[warmup_steps],
+            last_epoch=last_epoch,
+        )
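Not part of the diff: a rough, closed-form reference for the learning-rate curve the chained LinearLR and CosineAnnealingLR schedulers are expected to produce, assuming SequentialLR hands over to the cosine phase exactly at `warmup_steps`. The function `expected_lr` and the printed values are illustrative only, not repository code.

```python
import math


def expected_lr(
    step: int, warmup_steps: int, total_steps: int, initial_lr: float, final_lr: float, max_lr: float
) -> float:
    """Closed-form sketch of the warmup + cosine schedule built by the factory above."""
    if step <= warmup_steps:
        # Linear warmup: initial_lr at step 0, max_lr at step == warmup_steps.
        return initial_lr + (max_lr - initial_lr) * step / warmup_steps
    # Cosine annealing from max_lr down to final_lr over the remaining steps.
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return final_lr + (max_lr - final_lr) * (1 + math.cos(math.pi * progress)) / 2


# Same settings as the new test below: warmup_steps=2, total_steps=6, initial_lr=0.1, final_lr=0.2, max_lr=1.0.
print([round(expected_lr(s, 2, 6, 0.1, 0.2, 1.0), 3) for s in range(7)])
# Roughly [0.1, 0.55, 1.0, 0.883, 0.6, 0.317, 0.2]: warmup peaks at max_lr, annealing ends at final_lr.
```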
9 changes: 8 additions & 1 deletion src/modalities/registry/components.py
@@ -48,6 +48,7 @@
     GPT2MFUCalculatorConfig,
     GPT2ModelTPConfig,
     LinearLRSchedulerConfig,
+    LinearWarmupCosineAnnealingLRSchedulerConfig,
     LLMDataLoaderConfig,
     MemMapDatasetConfig,
     OneCycleLRSchedulerConfig,
@@ -108,7 +109,7 @@
     ComposedInitializationRoutines,
     ComposedModelInitializationConfig,
 )
-from modalities.optimizers.lr_schedulers import DummyLRScheduler
+from modalities.optimizers.lr_schedulers import DummyLRScheduler, LRSchedulerFactory
 from modalities.optimizers.optimizer_factory import OptimizerFactory
 from modalities.optimizers.optimizer_list import OptimizersList
 from modalities.optimizers.scheduler_list import SchedulerList
@@ -285,6 +286,12 @@ class ComponentEntity:
         maybe_optimizer_list(torch.optim.lr_scheduler.CosineAnnealingLR),
         CosineAnnealingLRSchedulerConfig,
     ),
+    ComponentEntity(
+        "scheduler",
+        "linear_warmup_cosine_annealing_lr",
+        maybe_optimizer_list(LRSchedulerFactory.get_linear_warmup_cosine_annealing_lr_scheduler),
+        LinearWarmupCosineAnnealingLRSchedulerConfig,
+    ),
     # tokenizers
     ComponentEntity("tokenizer", "pretrained_hf_tokenizer", PreTrainedHFTokenizer, PreTrainedHFTokenizerConfig),
     ComponentEntity("tokenizer", "pretrained_sp_tokenizer", PreTrainedSPTokenizer, PreTrainedSPTokenizerConfig),
24 changes: 24 additions & 0 deletions src/modalities/trainer.py
@@ -1,5 +1,6 @@
 from datetime import datetime
 from enum import Enum
+import gc
 from typing import Callable, Optional
 
 import torch
@@ -26,6 +27,25 @@
 from modalities.utils.typing_utils import FSDPX
 
 
+class GarbageCollection:
+    # Some portions of this implementation are inspired, adapted, or refactored
+    # from Meta's open-source project TorchTitan,
+    # licensed under the BSD 3-Clause License.
+    def __init__(self, gc_freq: int = 10):
+        assert gc_freq > 0, "gc_freq must be a positive integer"
+        self.gc_freq = gc_freq
+        gc.disable()
+        self.collect()  # GC invoked here
+
+    def run(self, step_count: int):
+        if step_count > 1 and step_count % self.gc_freq == 0:
+            self.collect()  # GC invoked here
+
+    @staticmethod
+    def collect(generation: int = 1):
+        gc.collect(generation)  # GC invoked here
+
+
 class ThroughputAggregationKeys(Enum):
     NUM_SAMPLES = "NUM_SAMPLES"
     FORWARD_BACKWARD_TIME = "FORWARD_BACKWARD_TIME"
@@ -70,6 +90,7 @@ def __init__(
         Returns:
             None
         """
+        self.gc = GarbageCollection(gc_freq=10)
         self.global_rank = global_rank
         if device_mesh is not None:
             self.dp_degree = get_parallel_degree(
@@ -283,6 +304,7 @@ def train(
             )
             # Check if model performance should be logged
             if training_progress.num_seen_steps_total % training_log_interval_in_steps == 0 and step_performed:
+                dist.barrier()
                 forward_backward_time_recorder.stop()
                 forward_backward_time = forward_backward_time_recorder.delta_t
                 forward_backward_time_recorder.reset()
@@ -363,8 +385,10 @@
 
             cumulated_losses.zero_()
             if step_performed:
+                self.gc.run(step_count=training_progress.num_seen_steps_total)
                 evaluation_callback(num_train_steps_done=training_progress.num_seen_steps_total)
                 checkpointing_callback(training_progress=training_progress)
+
             profiler_cm.step()
 
     @staticmethod
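Not part of the diff: a hypothetical driver loop (not Trainer code) showing how the GarbageCollection helper added above gates explicit collections. It assumes the class can be imported from modalities.trainer once this change lands.

```python
from modalities.trainer import GarbageCollection

gc_handler = GarbageCollection(gc_freq=10)  # disables automatic GC and runs one generation-1 collection up front
for step_count in range(1, 26):
    # ... forward/backward/optimizer work would happen here ...
    gc_handler.run(step_count=step_count)  # triggers gc.collect(1) only at steps 10 and 20
```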
27 changes: 26 additions & 1 deletion tests/test_lr_scheduler.py
@@ -1,14 +1,15 @@
 from unittest.mock import MagicMock, call
 
 import numpy as np
+import torch
 
 from modalities.checkpointing.checkpoint_saving import CheckpointSaving
 from modalities.checkpointing.stateful.app_state import AppState
 from modalities.dataloader.dataloader import LLMDataLoader
 from modalities.evaluator import Evaluator
 from modalities.gym import Gym
 from modalities.loss_functions import Loss
-from modalities.optimizers.lr_schedulers import DummyLRScheduler
+from modalities.optimizers.lr_schedulers import DummyLRScheduler, LRSchedulerFactory
 from modalities.trainer import Trainer
 from tests.utility import configure_dataloader_mock
 
@@ -76,3 +77,27 @@ def test_dummy_lr_scheduler(optimizer_with_param_groups_mock: MagicMock):
     assert np.allclose(scheduler.get_lr(), [0.08, 0.18, 0.28], atol=1e-6)
     assert scheduler._get_closed_form_lr() == [0.1, 0.2, 0.3]
     assert np.allclose(scheduler.get_last_lr(), [0.08, 0.18, 0.28], atol=1e-6)
+
+
+def test_linear_warmup_cosine_annealing_lr_scheduler():
+    parameter = torch.nn.Parameter(torch.tensor([1.0]))
+    optimizer = torch.optim.SGD([parameter], lr=1.0)
+    scheduler = LRSchedulerFactory.get_linear_warmup_cosine_annealing_lr_scheduler(
+        optimizer=optimizer,
+        warmup_steps=2,
+        total_steps=6,
+        initial_lr=0.1,
+        final_lr=0.2,
+        max_lr=1.0,
+    )
+
+    learning_rates = [scheduler.get_last_lr()[0]]
+    for _ in range(6):
+        optimizer.step()
+        scheduler.step()
+        learning_rates.append(scheduler.get_last_lr()[0])
+
+    assert learning_rates[0] < learning_rates[1] < learning_rates[2]
+    assert np.isclose(learning_rates[2], 1.0, atol=1e-6)
+    assert learning_rates[2] > learning_rates[3] > learning_rates[4] > learning_rates[5] > learning_rates[6]
+    assert np.isclose(learning_rates[6], 0.2, atol=1e-6)