-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfsdp_train.py
More file actions
262 lines (205 loc) · 8.8 KB
/
fsdp_train.py
File metadata and controls
262 lines (205 loc) · 8.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
"""
train/fsdp_train.py
Distributed fine-tuning of GPT-2 using PyTorch FSDP (Fully Sharded Data Parallel)
with BF16 mixed precision training.
Key design decisions:
- FSDP shards optimizer states, gradients, AND parameters across ranks
(vs DDP, which keeps a full replica of all three on every rank and only
all-reduces gradients) — critical for scaling to large models
where optimizer states alone exceed single-GPU memory
- BF16 preferred over FP16 for training stability: same memory savings
but larger dynamic range prevents loss spikes on transformer models
- CPU offloading disabled by default — enabled via the --cpu_offload flag for
larger models that exceed GPU memory even with sharding
Usage:
torchrun --nproc_per_node=4 train/fsdp_train.py \
--model gpt2 \
--dataset wikitext \
--epochs 3 \
--output_dir checkpoints/
References:
- PyTorch FSDP: https://pytorch.org/docs/stable/fsdp.html
- BF16 training stability: Kalamkar et al. (2019)
"""
import argparse
import os
import time
from contextlib import contextmanager
from functools import partial
import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
MixedPrecision,
BackwardPrefetch,
ShardingStrategy,
CPUOffload,
StateDictType,
FullStateDictConfig,
)
from torch.distributed.fsdp.wrap import (
transformer_auto_wrap_policy,
size_based_auto_wrap_policy,
)
from transformers import (
GPT2LMHeadModel,
GPT2Config,
GPT2Tokenizer,
DataCollatorForLanguageModeling,
)
from torch.utils.data import DataLoader, DistributedSampler
from datasets import load_dataset
# ── BF16 mixed precision policy ───────────────────────────────────────────────
# Passed to FSDP as its `mixed_precision` argument. Gradients are reduced in
# FP32 while parameters and buffers use BF16.
BF16_POLICY = MixedPrecision(
    reduce_dtype=torch.float32,    # dtype used for gradient reduction across ranks
    param_dtype=torch.bfloat16,    # dtype of parameters during forward/backward compute
    buffer_dtype=torch.bfloat16,   # dtype of registered module buffers
)
def setup_distributed():
    """Initialise the NCCL process group and bind this process to its GPU.

    The GPU index is taken from the ``LOCAL_RANK`` environment variable that
    torchrun exports for every worker. The global rank must not be used as a
    device index: on multi-node jobs it exceeds the number of GPUs per node.
    When ``LOCAL_RANK`` is absent, the global rank modulo the visible device
    count is used as a fallback.

    Returns:
        tuple[int, int]: ``(rank, world_size)`` of the initialised group.
    """
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    local_rank = os.environ.get("LOCAL_RANK")
    if local_rank is not None:
        device_index = int(local_rank)
    else:
        # Not launched via torchrun: fold the global rank onto local devices.
        device_index = rank % max(torch.cuda.device_count(), 1)
    torch.cuda.set_device(device_index)
    return rank, world_size
def cleanup_distributed():
    """Destroy the default process group if one is currently initialised.

    The guard makes this safe to call from error-handling paths (for example
    a ``finally`` block) where initialisation may never have completed, and
    safe to call more than once.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
def get_model_and_tokenizer(model_name: str, rank: int):
    """Load the GPT-2 tokenizer and LM-head model for later FSDP wrapping.

    Every rank runs this and loads its own copy of the weights; ``rank`` is
    only used to restrict the progress message to rank 0. No device is
    requested here, so placement on the GPU is left to the FSDP wrapper.

    The weights are loaded in FP32 on purpose. FSDP's ``MixedPrecision``
    policy casts parameters to BF16 for compute while the sharded master
    copy (and therefore the optimizer state) keeps the dtype the model was
    loaded with. Loading in BF16 would make the master weights BF16 too and
    lose the precision benefit of the mixed-precision setup.

    Args:
        model_name: Hugging Face model id or local path.
        rank: Global rank of this process.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    # GPT-2 has no pad token; reuse EOS so padded batches can be built.
    tokenizer.pad_token = tokenizer.eos_token

    if rank == 0:
        print(f"Loading {model_name} on CPU for FSDP sharding...")

    model = GPT2LMHeadModel.from_pretrained(model_name, torch_dtype=torch.float32)
    return model, tokenizer
def wrap_model_fsdp(model, args, rank):
    """Return ``model`` wrapped in FullyShardedDataParallel.

    Configuration applied:
      * ``ShardingStrategy.FULL_SHARD`` — parameters, gradients and optimizer
        states are sharded across ranks.
      * ``transformer_auto_wrap_policy`` keyed on ``GPT2Block`` — each
        transformer block becomes its own FSDP unit.
      * ``BF16_POLICY`` as the mixed-precision policy.
      * Parameter CPU offload controlled by ``args.cpu_offload``.
      * ``BACKWARD_PRE`` prefetching and rate-limited all-gathers.

    Args:
        model: The unwrapped module to shard.
        args: Parsed CLI namespace; only ``cpu_offload`` is read.
        rank: Global rank; rank 0 prints a confirmation line.

    Returns:
        The FSDP-wrapped model.
    """
    from transformers.models.gpt2.modeling_gpt2 import GPT2Block

    fsdp_options = dict(
        auto_wrap_policy=partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={GPT2Block},
        ),
        mixed_precision=BF16_POLICY,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        cpu_offload=CPUOffload(offload_params=args.cpu_offload),
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        device_id=torch.cuda.current_device(),
        limit_all_gathers=True,  # rate-limit concurrent all-gathers
    )
    sharded = FSDP(model, **fsdp_options)

    if rank == 0:
        print("Model wrapped with FSDP — ShardingStrategy: FULL_SHARD, Mixed Precision: BF16")
    return sharded
def get_dataloader(tokenizer, split: str, rank: int, world_size: int,
                   batch_size: int, seq_len: int = 512):
    """Build a WikiText-103 DataLoader backed by a DistributedSampler.

    Blank lines are dropped before tokenisation. With ``padding="max_length"``
    they would otherwise become rows made entirely of pad tokens, which the
    causal-LM collator masks out of the labels, so they carry no training
    signal and only waste compute.

    Args:
        tokenizer: Tokenizer with a pad token set.
        split: Dataset split name; shuffling is enabled only for ``"train"``.
        rank: Global rank of this process (selects the sampler shard).
        world_size: Total number of processes.
        batch_size: Per-rank batch size.
        seq_len: Length every sequence is truncated or padded to.

    Returns:
        torch.utils.data.DataLoader: batches produced by
        ``DataCollatorForLanguageModeling(mlm=False)``.
    """
    dataset = load_dataset("wikitext", "wikitext-103-v1", split=split)
    # Remove empty / whitespace-only rows (see docstring).
    dataset = dataset.filter(lambda example: bool(example["text"].strip()))

    def tokenize(examples):
        # Plain Python lists are enough here; set_format("torch") below
        # handles tensor conversion when rows are read back.
        return tokenizer(
            examples["text"],
            truncation=True,
            max_length=seq_len,
            padding="max_length",
        )

    tokenized = dataset.map(tokenize, batched=True, remove_columns=["text"])
    tokenized.set_format("torch")

    sampler = DistributedSampler(
        tokenized,
        num_replicas=world_size,
        rank=rank,
        shuffle=(split == "train"),
    )
    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    return DataLoader(
        tokenized,
        batch_size=batch_size,
        sampler=sampler,
        collate_fn=collator,
        pin_memory=True,
        num_workers=4,
    )
def train_epoch(model, loader, optimizer, scheduler, rank, epoch):
    """Run one training epoch and return this rank's mean loss.

    Args:
        model: Module called with keyword tensors and returning an object
            with a ``.loss`` attribute (here, the FSDP-wrapped GPT-2).
        loader: Iterable of mapping-like batches with ``"input_ids"`` and
            ``"labels"`` and, optionally, ``"attention_mask"``.
        optimizer: Optimizer stepped once per batch.
        scheduler: LR scheduler stepped once per batch.
        rank: Global rank; only rank 0 prints progress.
        epoch: Epoch number, used in the progress line only.

    Returns:
        float: Mean of the per-step losses seen by this rank, or ``0.0`` if
        the loader yielded no batches. The value is not reduced across ranks.
    """
    model.train()
    # ``.to("cuda")`` targets the current CUDA device, like ``.cuda()``;
    # the CPU fallback only matters when no GPU is present.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # FSDP shards gradients across ranks, so the generic torch utility would
    # clip against a per-rank local norm. FSDP's own ``clip_grad_norm_``
    # computes the global norm; fall back to the utility for plain modules.
    fsdp_clip = getattr(model, "clip_grad_norm_", None)

    total_loss = 0.0
    num_steps = 0
    t0 = time.time()

    for step, batch in enumerate(loader):
        model_inputs = {
            "input_ids": batch["input_ids"].to(device, non_blocking=True),
            "labels": batch["labels"].to(device, non_blocking=True),
        }
        if "attention_mask" in batch:
            # Sequences are padded, so pass the mask on to the model.
            model_inputs["attention_mask"] = batch["attention_mask"].to(
                device, non_blocking=True
            )

        optimizer.zero_grad()
        loss = model(**model_inputs).loss
        loss.backward()

        # No loss scaling is used with BF16, so nothing needs unscaling.
        if callable(fsdp_clip):
            fsdp_clip(1.0)
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        scheduler.step()

        step_loss = loss.item()
        total_loss += step_loss
        num_steps += 1

        if rank == 0 and step % 50 == 0:
            elapsed = time.time() - t0
            print(f" Epoch {epoch} | Step {step} | Loss {step_loss:.4f} | "
                  f"LR {scheduler.get_last_lr()[0]:.2e} | {elapsed:.1f}s")

    # Guard against an empty loader instead of raising ZeroDivisionError.
    return total_loss / max(num_steps, 1)
def save_checkpoint(model, optimizer, rank, epoch, output_dir):
    """Gather full model and optimizer state and write it from rank 0.

    Uses ``FULL_STATE_DICT`` with CPU offload and ``rank0_only`` for both the
    model and the optimizer. ``FSDP.optim_state_dict`` is required for the
    optimizer: a plain ``optimizer.state_dict()`` only holds this rank's shard
    of the state, so saving it from rank 0 would give an incomplete checkpoint.

    Both gather calls are collectives and must run on every rank; only the
    file write is restricted to rank 0. For models too large to gather on one
    host, switch to ``SHARDED_STATE_DICT``.

    Args:
        model: FSDP-wrapped model.
        optimizer: Optimizer built over ``model.parameters()``.
        rank: Global rank of this process.
        epoch: Epoch number, stored in the checkpoint and used in the filename.
        output_dir: Directory for the checkpoint; created if missing.
    """
    # Function-scope import keeps this block self-contained.
    from torch.distributed.fsdp import FullOptimStateDictConfig

    os.makedirs(output_dir, exist_ok=True)

    model_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    optim_cfg = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)

    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT,
                              model_cfg, optim_cfg):
        model_state = model.state_dict()
        optim_state = FSDP.optim_state_dict(model, optimizer)

    if rank == 0:
        path = os.path.join(output_dir, f"checkpoint_epoch_{epoch}.pt")
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model_state,
                "optimizer_state_dict": optim_state,
            },
            path,
        )
        print(f"Checkpoint saved: {path}")
def main(args):
    """Entry point: set up distributed state, train, and checkpoint each epoch.

    Args:
        args: Parsed CLI namespace. Reads ``model``, ``batch_size``,
            ``seq_len``, ``lr``, ``epochs``, ``output_dir`` and
            ``cpu_offload``. ``args.dataset`` is currently not consulted;
            the data loader always reads WikiText-103.
    """
    rank, world_size = setup_distributed()
    try:
        model, tokenizer = get_model_and_tokenizer(args.model, rank)
        # Count parameters before wrapping: after FSDP shards the model,
        # ``parameters()`` exposes only this rank's shard.
        total_params = sum(p.numel() for p in model.parameters())
        model = wrap_model_fsdp(model, args, rank)

        # No validation loader is built: this script has no evaluation
        # loop, so preparing one would be unused work.
        train_loader = get_dataloader(tokenizer, "train", rank, world_size,
                                      args.batch_size, args.seq_len)

        # The fused AdamW kernel expects parameters on the GPU; with CPU
        # parameter offload they live on the host, so use the regular path.
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr,
                                      weight_decay=0.01,
                                      fused=not args.cpu_offload)
        # Clamp to 1 so a zero-length schedule cannot divide by zero.
        total_steps = max(1, args.epochs * len(train_loader))
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=total_steps)

        if rank == 0:
            print(f"\nModel: {args.model} | Params: {total_params/1e6:.1f}M | "
                  f"World size: {world_size} | BF16: True\n")

        for epoch in range(1, args.epochs + 1):
            # Reseed the sampler so each epoch gets a different shuffle.
            train_loader.sampler.set_epoch(epoch)
            train_loss = train_epoch(model, train_loader, optimizer,
                                     scheduler, rank, epoch)
            if rank == 0:
                print(f"Epoch {epoch} complete | Train loss: {train_loss:.4f}")
            save_checkpoint(model, optimizer, rank, epoch, args.output_dir)
    finally:
        # Always release the process group, even when training fails.
        cleanup_distributed()
if __name__ == "__main__":
    # Command-line options, declared as (flag, add_argument keyword arguments).
    cli_options = (
        ("--model", {"default": "gpt2"}),
        ("--dataset", {"default": "wikitext"}),
        ("--epochs", {"type": int, "default": 3}),
        ("--batch_size", {"type": int, "default": 8}),
        ("--seq_len", {"type": int, "default": 512}),
        ("--lr", {"type": float, "default": 3e-4}),
        ("--output_dir", {"default": "checkpoints/"}),
        ("--cpu_offload", {"action": "store_true"}),
    )
    cli = argparse.ArgumentParser()
    for flag, spec in cli_options:
        cli.add_argument(flag, **spec)
    main(cli.parse_args())