# Heavily modified version of nanochat gpt.py to do diffusion
# https://github.com/karpathy/nanochat/blob/master/nanochat/gpt.py
#
# Config is based on hyperparameters from Karpathy's "Let's build GPT" video
# https://github.com/karpathy/ng-video-lecture/blob/master/gpt.py
#
# Tokenizer is a simple ASCII mapping
"""
Simple Character-Level Discrete Diffusion Transformer
Major changes from nanochat/gpt.py:
- Bidirectional attention instead of causal (no KV cache)
- Added timestep conditioning (time embeddings)
- Replaced autoregressive generation with top-k and confidence-aware parallel decoding
- Removed MQA/GQA (n_kv_head), simplified to standard multi-head attention
- Removed optimizer setup, FLOPs estimation, and embedding dtype casting
"""
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
@dataclass
class DiffusionConfig:
sequence_len: int = 256
vocab_size: int = 128 # Full ASCII (0-127), where 0 is reserved for mask
mask_token_id: int = 0 # NUL character used as [MASK] token
n_layer: int = 6
n_head: int = 6
n_embd: int = 384
diffusion_steps: int = 128
context_len: int = 16 # Number of prefix tokens that are never masked
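
# Example (hypothetical values, not from the repo): a smaller config for quick
# CPU experiments; n_embd must divide evenly by n_head (here head_dim = 256 // 4 = 64).
#   cfg = DiffusionConfig(n_layer=4, n_head=4, n_embd=256, sequence_len=128)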

def norm(x):
    # Purely functional RMSNorm with no learnable params
    return F.rms_norm(x, (x.size(-1),))

def apply_rotary_emb(x, cos, sin):
    assert x.ndim == 4  # multihead attention
    d = x.shape[3] // 2
    x1, x2 = x[..., :d], x[..., d:]  # split the last dim into two halves
    y1 = x1 * cos + x2 * sin  # rotate pairs of dims
    y2 = x1 * (-sin) + x2 * cos
    out = torch.cat([y1, y2], 3)  # re-assemble
    out = out.to(x.dtype)  # ensure input/output dtypes match
    return out
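
# Shape note (inferred from the code above, not stated in the original):
# x is (B, T, n_head, head_dim) while cos/sin broadcast as (1, T, 1, head_dim // 2),
# so one rotation table is shared across the batch and all heads. At position 0,
# cos = 1 and sin = 0, so the first token passes through unrotated.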

class BidirectionalAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        self.c_q = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.c_k = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.c_v = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

    def forward(self, x, cos_sin):
        B, T, C = x.size()
        # Project the input to get queries, keys, and values
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_head, self.head_dim)
        # Apply rotary embeddings to queries and keys
        cos, sin = cos_sin
        q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
        q, k = norm(q), norm(k)  # QK norm
        q, k, v = (
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
        )  # (B, T, H, D) -> (B, H, T, D)
        # Bidirectional attention - no causal masking, since denoising lets
        # masked tokens draw context from both sides
        y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
        # Re-assemble the heads and project back
        y = y.transpose(1, 2).contiguous().view(B, T, -1)
        y = self.c_proj(y)
        return y

class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)

    def forward(self, x):
        x = self.c_fc(x)
        x = F.relu(x).square()  # squared-ReLU (ReLU^2) nonlinearity
        x = self.c_proj(x)
        return x


class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attn = BidirectionalAttention(config)
        self.mlp = MLP(config)

    def forward(self, x, cos_sin):
        x = x + self.attn(norm(x), cos_sin)  # pre-norm residual attention
        x = x + self.mlp(norm(x))  # pre-norm residual MLP
        return x

class DiffusionTransformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # Token and time embeddings (the mask token is part of the vocab)
        self.token_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.time_emb = nn.Embedding(config.diffusion_steps, config.n_embd)
        # Transformer blocks
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        # Output head to predict denoised tokens
        self.output_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Rotary embeddings
        self.rotary_seq_len = config.sequence_len * 2
        head_dim = config.n_embd // config.n_head
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)

    def init_weights(self):
        self.apply(self._init_weights)
        # Zero out the output head weights
        torch.nn.init.zeros_(self.output_head.weight)
        # Zero out c_proj weights in all blocks so residual branches start as identity
        for block in self.blocks:
            torch.nn.init.zeros_(block.mlp.c_proj.weight)
            torch.nn.init.zeros_(block.attn.c_proj.weight)
        # Re-initialize the rotary embeddings
        head_dim = self.config.n_embd // self.config.n_head
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.cos, self.sin = cos, sin

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            # Scaled init (https://arxiv.org/pdf/2310.17813): std = fan_in^-0.5,
            # shrunk further when fan_out < fan_in
            fan_out = module.weight.size(0)
            fan_in = module.weight.size(1)
            std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)

    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
        if device is None:
            device = self.token_emb.weight.device
        channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (base ** (channel_range / head_dim))
        t = torch.arange(seq_len, dtype=torch.float32, device=device)
        freqs = torch.outer(t, inv_freq)
        cos, sin = freqs.cos(), freqs.sin()
        cos, sin = (
            cos[None, :, None, :],
            sin[None, :, None, :],
        )  # add batch and head dims -> (1, seq_len, 1, head_dim // 2)
        return cos, sin

    def get_device(self):
        return self.token_emb.weight.device

    def forward(self, x_t, t):
        """
        Forward pass for the diffusion model.

        Args:
            x_t: Noisy tokens at timestep t, shape (B, T)
            t: Timestep indices, shape (B,)
        Returns:
            logits: Predicted token logits, shape (B, T, vocab_size)
        """
        B, T = x_t.size()
        # Get embeddings
        x = self.token_emb(x_t)  # (B, T, n_embd)
        t_emb = self.time_emb(t)  # (B, n_embd)
        # Add the time embedding to every position (broadcast across the sequence)
        x = x + t_emb.unsqueeze(1)
        x = norm(x)
        # Get rotary embeddings
        assert T <= self.cos.size(1)
        cos_sin = (self.cos[:, :T], self.sin[:, :T])
        # Forward through the transformer blocks
        for block in self.blocks:
            x = block(x, cos_sin)
        x = norm(x)
        # Predict denoised tokens
        logits = self.output_head(x)  # (B, T, vocab_size)
        return logits
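
    # Illustrative shape check (values are the config defaults; this snippet is
    # not part of the original file):
    #   model = DiffusionTransformer(DiffusionConfig()); model.init_weights()
    #   x_t = torch.randint(0, 128, (2, 256)); t = torch.randint(0, 128, (2,))
    #   model(x_t, t).shape  ->  torch.Size([2, 256, 128])  # (B, T, vocab_size)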

    @torch.inference_mode()
    def sample_topk(
        self,
        batch_size,
        seq_len,
        k,
        num_steps=None,
        temperature=1.0,
        device=None,
        context_tokens=None,
    ):
        """
        Generate samples using top-k parallel decoding (LLaDA baseline).
        At each step, decode exactly the k highest-confidence tokens.

        Args:
            batch_size: Number of samples to generate
            seq_len: Length of sequences to generate
            k: Number of tokens to decode per step
            num_steps: Maximum number of denoising steps
            temperature: Sampling temperature
            device: Device to generate on
            context_tokens: Optional context tokens for conditioning, shape (batch_size, context_len)
        Returns:
            samples: Generated token sequences, shape (batch_size, seq_len)
        """
        if device is None:
            device = self.get_device()
        if num_steps is None:
            num_steps = seq_len  # Maximum possible steps
        # Start from all mask tokens
        x = torch.full(
            (batch_size, seq_len),
            self.config.mask_token_id,
            dtype=torch.long,
            device=device,
        )
        # If context tokens are provided, fix them in the first context_len positions
        if context_tokens is not None:
            context_len = context_tokens.size(1)
            x[:, :context_len] = context_tokens.to(device)
        # Track which positions are still masked
        masked_positions = torch.ones(
            batch_size, seq_len, dtype=torch.bool, device=device
        )
        if context_tokens is not None:
            masked_positions[:, :context_len] = False
        # Decode step by step
        for step in range(num_steps):
            # Stop once all tokens are decoded
            if not masked_positions.any():
                break
            # Create timestep (the step index is used as a proxy for the timestep)
            t_batch = torch.full((batch_size,), step, device=device, dtype=torch.long)
            t_batch = torch.clamp(t_batch, 0, self.config.diffusion_steps - 1)
            # Predict tokens
            logits = self.forward(x, t_batch)
            # Get confidence scores (max probability at each position)
            probs = F.softmax(logits / temperature, dim=-1)
            confidences, predicted_tokens = torch.max(probs, dim=-1)  # (B, T)
            # Mask out already-decoded positions
            confidences = confidences.masked_fill(~masked_positions, -float("inf"))
            # Select the top-k positions per batch element
            k_actual = min(k, masked_positions.sum(dim=1).max().item())
            _, topk_indices = torch.topk(confidences, k=k_actual, dim=1)  # (B, K)
            # Commit the top-k predictions (the mask check guards batch elements
            # with fewer than k_actual remaining positions)
            for b in range(batch_size):
                for idx in topk_indices[b]:
                    if masked_positions[b, idx]:
                        x[b, idx] = predicted_tokens[b, idx]
                        masked_positions[b, idx] = False
        return x
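
    # Example call (illustrative; k chosen arbitrarily): decoding 8 tokens per
    # step on a 256-token canvas finishes in at most 32 steps.
    #   samples = model.sample_topk(batch_size=4, seq_len=256, k=8)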

    @torch.inference_mode()
    def sample_confidence(
        self,
        batch_size,
        seq_len,
        confidence_threshold=0.95,
        num_steps=None,
        temperature=1.0,
        device=None,
        context_tokens=None,
    ):
        """
        Generate samples using confidence-aware parallel decoding (Fast-dLLM).
        At each step, decode all tokens whose confidence exceeds a threshold.

        Args:
            batch_size: Number of samples to generate
            seq_len: Length of sequences to generate
            confidence_threshold: Threshold τ for token acceptance
            num_steps: Maximum number of denoising steps
            temperature: Sampling temperature
            device: Device to generate on
            context_tokens: Optional context tokens for conditioning, shape (batch_size, context_len)
        Returns:
            samples: Generated token sequences, shape (batch_size, seq_len)
        """
        if device is None:
            device = self.get_device()
        if num_steps is None:
            num_steps = seq_len  # Maximum possible steps
        # Start from all mask tokens
        x = torch.full(
            (batch_size, seq_len),
            self.config.mask_token_id,
            dtype=torch.long,
            device=device,
        )
        # If context tokens are provided, fix them in the first context_len positions
        if context_tokens is not None:
            context_len = context_tokens.size(1)
            x[:, :context_len] = context_tokens.to(device)
        # Track which positions are still masked
        masked_positions = torch.ones(
            batch_size, seq_len, dtype=torch.bool, device=device
        )
        if context_tokens is not None:
            masked_positions[:, :context_len] = False
        # Decode step by step
        for step in range(num_steps):
            # Stop once all tokens are decoded
            if not masked_positions.any():
                break
            # Create timestep (the step index is used as a proxy for the timestep)
            t_batch = torch.full((batch_size,), step, device=device, dtype=torch.long)
            t_batch = torch.clamp(t_batch, 0, self.config.diffusion_steps - 1)
            # Predict tokens
            logits = self.forward(x, t_batch)
            # Get confidence scores (max probability at each position)
            probs = F.softmax(logits / temperature, dim=-1)
            confidences, predicted_tokens = torch.max(probs, dim=-1)  # (B, T)
            # Select positions above the threshold (only among masked positions)
            above_threshold = (confidences >= confidence_threshold) & masked_positions
            # Guarantee progress: decode at least one token per batch element
            # whenever any position remains masked
            for b in range(batch_size):
                if masked_positions[b].any() and not above_threshold[b].any():
                    # Decode the highest-confidence masked token
                    masked_confidences = confidences[b].clone()
                    masked_confidences[~masked_positions[b]] = -float("inf")
                    best_idx = torch.argmax(masked_confidences)
                    above_threshold[b, best_idx] = True
            # Commit all positions above the threshold
            x = torch.where(above_threshold, predicted_tokens, x)
            masked_positions = masked_positions & ~above_threshold
        return x
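
    # Example call (illustrative): a lower threshold accepts more tokens per
    # step, trading sample quality for fewer forward passes.
    #   samples = model.sample_confidence(batch_size=4, seq_len=256, confidence_threshold=0.9)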

    @torch.inference_mode()
    def sample(
        self,
        batch_size,
        seq_len,
        num_steps=None,
        temperature=1.0,
        device=None,
        context_tokens=None,
        method="confidence",
        k=None,
        confidence_threshold=0.95,
    ):
        """
        Generate samples using parallel decoding methods.

        Args:
            batch_size: Number of samples to generate
            seq_len: Length of sequences to generate
            num_steps: Maximum number of denoising steps
            temperature: Sampling temperature
            device: Device to generate on
            context_tokens: Optional context tokens for conditioning, shape (batch_size, context_len)
            method: Decoding method - 'topk' or 'confidence'
            k: Number of tokens per step (for the 'topk' method)
            confidence_threshold: Confidence threshold τ (for the 'confidence' method)
        Returns:
            samples: Generated token sequences, shape (batch_size, seq_len)
        """
        if method == "topk":
            if k is None:
                k = max(1, seq_len // 10)  # Default: decode 10% per step
            return self.sample_topk(
                batch_size, seq_len, k, num_steps, temperature, device, context_tokens
            )
        elif method == "confidence":
            return self.sample_confidence(
                batch_size,
                seq_len,
                confidence_threshold,
                num_steps,
                temperature,
                device,
                context_tokens,
            )
        else:
            raise ValueError(f"Unknown sampling method: {method}")

def encode_text(text):
    """Convert text to vocab indices using a direct ASCII mapping.
    Code points above 127 are clamped to 127 (DEL).
    """
    tokens = torch.tensor([min(ord(c), 127) for c in text], dtype=torch.long)
    return tokens


def decode_tokens(tokens):
    """Convert vocab indices to text using a direct ASCII mapping."""
    text = "".join([chr(int(t)) for t in tokens])
    return text