Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 169 additions & 74 deletions usage_examples/trainers/finetuner.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
# limitations under the License.

# This module does not embed third-party data download URLs.
import contextlib
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
from tqdm import tqdm
Expand All @@ -22,6 +25,44 @@
from usage_examples.trainers.model_wrapper import GFMWithProjection


def _is_cuda_device(device) -> bool:
"""
Accept str ("cuda", "cuda:0"), torch.device, or bare int (treated as a CUDA
ordinal, matching torch.cuda.set_device semantics).
"""
if isinstance(device, int):
return True
if isinstance(device, torch.device):
return device.type == "cuda"
return str(device).startswith("cuda")


def _amp_available(device) -> bool:
"""
AMP is usable when the device is CUDA, CUDA is actually present, and
torch.amp exposes autocast. Returns False on CPU/MPS.
"""
if not _is_cuda_device(device):
return False
if not torch.cuda.is_available():
return False
return hasattr(torch, "amp") and hasattr(torch.amp, "autocast")


def _pick_amp_dtype() -> torch.dtype:
"""
Prefer bfloat16 on hardware that supports it (Ampere+ / Hopper / etc.):
bf16 keeps fp32 dynamic range and needs no loss scaling. Fall back to
fp16 + GradScaler on older GPUs.
"""
try:
if torch.cuda.is_bf16_supported():
return torch.bfloat16
except Exception:
pass
return torch.float16


class GFMFinetuner:
"""
Fine-tuning module for genomic foundation models.
Expand All @@ -41,7 +82,9 @@ def __init__(
only_proj_layer=True,
is_variant_effect_prediction=False,
disable_cache=False,
device='cpu'
device='cpu',
use_amp=None,
amp_dtype=None,
):
"""
Initialize the fine-tuner.
Expand All @@ -59,6 +102,12 @@ def __init__(
is_variant_effect_prediction: if True, task uses variant/ref sequence pairs
disable_cache: if True, skip frozen backbone forward cache during linear probing
device: torch device
use_amp: enable CUDA mixed-precision for the full fine-tune path. None ->
auto: on iff CUDA is available and torch.amp is installed. Pass False
to force off. Has no effect on the linear-probe path, which stays fp32
so the cached backbone output remains numpy-serializable.
amp_dtype: autocast dtype. None -> auto: bfloat16 when supported (no loss
scaling needed), else float16 (with GradScaler).
"""
self.model = model
self.train_loader = train_loader
Expand All @@ -72,7 +121,24 @@ def __init__(
self.is_variant_effect_prediction = is_variant_effect_prediction
self.disable_cache = disable_cache
self.device = device


if use_amp is None:
use_amp = _amp_available(device)
elif use_amp and not _amp_available(device):
print("Warning: use_amp=True but CUDA/torch.amp unavailable; disabling AMP.")
use_amp = False
self.use_amp = bool(use_amp)
self.amp_dtype = amp_dtype if amp_dtype is not None else _pick_amp_dtype()

if self.use_amp:
if self.only_proj_layer:
print(
"AMP requested but not applied on the cached linear-probe path "
"(backbone is cached in fp32 for numpy-safe storage)."
)
else:
print(f"AMP enabled for full fine-tuning (dtype={self.amp_dtype}).")

# Create projection layer for classification
# For variant effect tasks, input is concatenated embeddings (hidden_dim * 2)
proj_input_dim = self.hidden_dim * 2 if is_variant_effect_prediction else self.hidden_dim
Expand Down Expand Up @@ -118,7 +184,21 @@ def fine_tune(self):

# Loss function for classification
criterion = torch.nn.CrossEntropyLoss()


# AMP is applied only to the full fine-tune path. On the linear-probe
# path the backbone output is cached as CPU numpy, and numpy has no
# bfloat16 dtype, so autocasting the cached forward would crash on store.
# The cache already eliminates the backbone cost across epochs there, so
# keeping that path in fp32 costs effectively nothing.
amp_active = self.use_amp and not self.only_proj_layer
needs_scaler = amp_active and self.amp_dtype == torch.float16
scaler = torch.amp.GradScaler("cuda", enabled=needs_scaler) if amp_active else None

def _autocast():
if not amp_active:
return contextlib.nullcontext()
return torch.amp.autocast(device_type="cuda", dtype=self.amp_dtype)

# Set to training mode
if not self.only_proj_layer:
self.model.train()
Expand All @@ -128,78 +208,93 @@ def fine_tune(self):

fwd_cache = SequenceInferenceCache() if self.only_proj_layer else None

# Training loop
for epoch in range(self.num_epochs):
total_loss = 0.0
num_batches = 0

progress_bar = tqdm(self.train_loader, desc=f"Fine-tuning epoch {epoch+1}/{self.num_epochs}")
for batch in progress_bar:
if self.is_variant_effect_prediction:
# Variant effect task: (variant_seqs, ref_seqs, labels, conditional_input)
variant_sequences, ref_sequences, labels, conditional_input = batch
labels = labels.to(self.device)

# Get representative embeddings for both sequences
# Use no_grad when only training projection layer (saves memory/compute)
if self.only_proj_layer:
with torch.no_grad():
var_repr = fwd_cache.cached_call(
self.model._sequence_to_representative,
variant_sequences,
disable=self.disable_cache,
)
ref_repr = fwd_cache.cached_call(
self.model._sequence_to_representative,
ref_sequences,
disable=self.disable_cache,
)
# Detach to ensure no gradient flow to model
var_repr = var_repr.detach()
ref_repr = ref_repr.detach()
else:
var_repr = self.model._sequence_to_representative(variant_sequences)
ref_repr = self.model._sequence_to_representative(ref_sequences)

# Concatenate variant and reference representations
sequence_repr = torch.cat([var_repr, ref_repr], dim=1)
else:
# Single sequence task: (sequences, labels, conditional_input)
sequences, labels, conditional_input = batch
labels = labels.to(self.device)

# Get representative embeddings
# Use no_grad when only training projection layer (saves memory/compute)
if self.only_proj_layer:
with torch.no_grad():
sequence_repr = fwd_cache.cached_call(
self.model._sequence_to_representative,
sequences,
disable=self.disable_cache,
)
# Detach to ensure no gradient flow to model
sequence_repr = sequence_repr.detach()
# When only the projection layer is trained, freeze backbone params for
# the duration of training. Inputs (token ids) never require grad, so with
# the backbone frozen its forward yields requires_grad=False outputs
# naturally -- no autograd graph is built and the per-batch no_grad()/
# detach() dance is unnecessary. Cache hits return fresh tensors rebuilt
# from numpy (also grad-free). State is restored in the finally block so
# the caller's model is not permanently mutated.
original_requires_grad: Optional[List[Tuple[torch.nn.Parameter, bool]]] = None
if self.only_proj_layer:
original_requires_grad = [
(p, p.requires_grad) for p in self.model.parameters()
]
for p, _ in original_requires_grad:
p.requires_grad_(False)

avg_loss = float("nan")
try:
# Training loop
for epoch in range(self.num_epochs):
total_loss = 0.0
num_batches = 0

progress_bar = tqdm(self.train_loader, desc=f"Fine-tuning epoch {epoch+1}/{self.num_epochs}")
for batch in progress_bar:
with _autocast():
if self.is_variant_effect_prediction:
# Variant effect task: (variant_seqs, ref_seqs, labels, conditional_input)
variant_sequences, ref_sequences, labels, conditional_input = batch
labels = labels.to(self.device)

if self.only_proj_layer:
var_repr = fwd_cache.cached_call(
self.model._sequence_to_representative,
variant_sequences,
disable=self.disable_cache,
)
ref_repr = fwd_cache.cached_call(
self.model._sequence_to_representative,
ref_sequences,
disable=self.disable_cache,
)
else:
var_repr = self.model._sequence_to_representative(variant_sequences)
ref_repr = self.model._sequence_to_representative(ref_sequences)

# Concatenate variant and reference representations
sequence_repr = torch.cat([var_repr, ref_repr], dim=1)
else:
# Single sequence task: (sequences, labels, conditional_input)
sequences, labels, conditional_input = batch
labels = labels.to(self.device)

if self.only_proj_layer:
sequence_repr = fwd_cache.cached_call(
self.model._sequence_to_representative,
sequences,
disable=self.disable_cache,
)
else:
sequence_repr = self.model._sequence_to_representative(sequences)

logits = self.projection(sequence_repr)
loss = criterion(logits, labels)

optimizer.zero_grad(set_to_none=True)
if needs_scaler:
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
else:
sequence_repr = self.model._sequence_to_representative(sequences)

logits = self.projection(sequence_repr)

loss = criterion(logits, labels)

optimizer.zero_grad()
loss.backward()
optimizer.step()

# Update loss tracking
total_loss += loss.item()
num_batches += 1
avg_loss = total_loss / num_batches

# Update progress bar with average loss
progress_bar.set_postfix({'avg_loss': f'{avg_loss:.4f}'})

if fwd_cache is not None:
fwd_cache.clear()
loss.backward()
optimizer.step()

# Update loss tracking
total_loss += loss.item()
num_batches += 1
avg_loss = total_loss / num_batches

# Update progress bar with average loss
progress_bar.set_postfix({'avg_loss': f'{avg_loss:.4f}'})
finally:
if original_requires_grad is not None:
for p, orig in original_requires_grad:
p.requires_grad_(orig)

if fwd_cache is not None:
fwd_cache.clear()

if self.num_epochs > 0:
print(f"Fine-tuning completed. Final average loss: {avg_loss:.4f}")
Expand Down