From f4e157768077584f15460181a4703873ddb9aa76 Mon Sep 17 00:00:00 2001 From: Raizy Kellerman Date: Thu, 11 Jun 2026 15:34:47 +0300 Subject: [PATCH] Performance opt: 1. AMP for the full fine-tune path, 2. zero_grad(set_to_none=True) 3. Backbone freeze when only projection head is trained; + robustness changes - avg_loss initialized to nan and fwd_cache.clear() moved to finally block to run even in case of exceptions, --- usage_examples/trainers/finetuner.py | 243 +++++++++++++++++++-------- 1 file changed, 169 insertions(+), 74 deletions(-) mode change 100644 => 100755 usage_examples/trainers/finetuner.py diff --git a/usage_examples/trainers/finetuner.py b/usage_examples/trainers/finetuner.py old mode 100644 new mode 100755 index dc2ae8d..0b1c391 --- a/usage_examples/trainers/finetuner.py +++ b/usage_examples/trainers/finetuner.py @@ -14,6 +14,9 @@ # limitations under the License. # This module does not embed third-party data download URLs. +import contextlib +from typing import List, Optional, Tuple + import torch import torch.nn as nn from tqdm import tqdm @@ -22,6 +25,44 @@ from usage_examples.trainers.model_wrapper import GFMWithProjection +def _is_cuda_device(device) -> bool: + """ + Accept str ("cuda", "cuda:0"), torch.device, or bare int (treated as a CUDA + ordinal, matching torch.cuda.set_device semantics). + """ + if isinstance(device, int): + return True + if isinstance(device, torch.device): + return device.type == "cuda" + return str(device).startswith("cuda") + + +def _amp_available(device) -> bool: + """ + AMP is usable when the device is CUDA, CUDA is actually present, and + torch.amp exposes autocast. Returns False on CPU/MPS. + """ + if not _is_cuda_device(device): + return False + if not torch.cuda.is_available(): + return False + return hasattr(torch, "amp") and hasattr(torch.amp, "autocast") + + +def _pick_amp_dtype() -> torch.dtype: + """ + Prefer bfloat16 on hardware that supports it (Ampere+ / Hopper / etc.): + bf16 keeps fp32 dynamic range and needs no loss scaling. Fall back to + fp16 + GradScaler on older GPUs. + """ + try: + if torch.cuda.is_bf16_supported(): + return torch.bfloat16 + except Exception: + pass + return torch.float16 + + class GFMFinetuner: """ Fine-tuning module for genomic foundation models. @@ -41,7 +82,9 @@ def __init__( only_proj_layer=True, is_variant_effect_prediction=False, disable_cache=False, - device='cpu' + device='cpu', + use_amp=None, + amp_dtype=None, ): """ Initialize the fine-tuner. @@ -59,6 +102,12 @@ def __init__( is_variant_effect_prediction: if True, task uses variant/ref sequence pairs disable_cache: if True, skip frozen backbone forward cache during linear probing device: torch device + use_amp: enable CUDA mixed-precision for the full fine-tune path. None -> + auto: on iff CUDA is available and torch.amp is installed. Pass False + to force off. Has no effect on the linear-probe path, which stays fp32 + so the cached backbone output remains numpy-serializable. + amp_dtype: autocast dtype. None -> auto: bfloat16 when supported (no loss + scaling needed), else float16 (with GradScaler). """ self.model = model self.train_loader = train_loader @@ -72,7 +121,24 @@ def __init__( self.is_variant_effect_prediction = is_variant_effect_prediction self.disable_cache = disable_cache self.device = device - + + if use_amp is None: + use_amp = _amp_available(device) + elif use_amp and not _amp_available(device): + print("Warning: use_amp=True but CUDA/torch.amp unavailable; disabling AMP.") + use_amp = False + self.use_amp = bool(use_amp) + self.amp_dtype = amp_dtype if amp_dtype is not None else _pick_amp_dtype() + + if self.use_amp: + if self.only_proj_layer: + print( + "AMP requested but not applied on the cached linear-probe path " + "(backbone is cached in fp32 for numpy-safe storage)." + ) + else: + print(f"AMP enabled for full fine-tuning (dtype={self.amp_dtype}).") + # Create projection layer for classification # For variant effect tasks, input is concatenated embeddings (hidden_dim * 2) proj_input_dim = self.hidden_dim * 2 if is_variant_effect_prediction else self.hidden_dim @@ -118,7 +184,21 @@ def fine_tune(self): # Loss function for classification criterion = torch.nn.CrossEntropyLoss() - + + # AMP is applied only to the full fine-tune path. On the linear-probe + # path the backbone output is cached as CPU numpy, and numpy has no + # bfloat16 dtype, so autocasting the cached forward would crash on store. + # The cache already eliminates the backbone cost across epochs there, so + # keeping that path in fp32 costs effectively nothing. + amp_active = self.use_amp and not self.only_proj_layer + needs_scaler = amp_active and self.amp_dtype == torch.float16 + scaler = torch.amp.GradScaler("cuda", enabled=needs_scaler) if amp_active else None + + def _autocast(): + if not amp_active: + return contextlib.nullcontext() + return torch.amp.autocast(device_type="cuda", dtype=self.amp_dtype) + # Set to training mode if not self.only_proj_layer: self.model.train() @@ -128,78 +208,93 @@ def fine_tune(self): fwd_cache = SequenceInferenceCache() if self.only_proj_layer else None - # Training loop - for epoch in range(self.num_epochs): - total_loss = 0.0 - num_batches = 0 - - progress_bar = tqdm(self.train_loader, desc=f"Fine-tuning epoch {epoch+1}/{self.num_epochs}") - for batch in progress_bar: - if self.is_variant_effect_prediction: - # Variant effect task: (variant_seqs, ref_seqs, labels, conditional_input) - variant_sequences, ref_sequences, labels, conditional_input = batch - labels = labels.to(self.device) - - # Get representative embeddings for both sequences - # Use no_grad when only training projection layer (saves memory/compute) - if self.only_proj_layer: - with torch.no_grad(): - var_repr = fwd_cache.cached_call( - self.model._sequence_to_representative, - variant_sequences, - disable=self.disable_cache, - ) - ref_repr = fwd_cache.cached_call( - self.model._sequence_to_representative, - ref_sequences, - disable=self.disable_cache, - ) - # Detach to ensure no gradient flow to model - var_repr = var_repr.detach() - ref_repr = ref_repr.detach() - else: - var_repr = self.model._sequence_to_representative(variant_sequences) - ref_repr = self.model._sequence_to_representative(ref_sequences) - - # Concatenate variant and reference representations - sequence_repr = torch.cat([var_repr, ref_repr], dim=1) - else: - # Single sequence task: (sequences, labels, conditional_input) - sequences, labels, conditional_input = batch - labels = labels.to(self.device) - - # Get representative embeddings - # Use no_grad when only training projection layer (saves memory/compute) - if self.only_proj_layer: - with torch.no_grad(): - sequence_repr = fwd_cache.cached_call( - self.model._sequence_to_representative, - sequences, - disable=self.disable_cache, - ) - # Detach to ensure no gradient flow to model - sequence_repr = sequence_repr.detach() + # When only the projection layer is trained, freeze backbone params for + # the duration of training. Inputs (token ids) never require grad, so with + # the backbone frozen its forward yields requires_grad=False outputs + # naturally -- no autograd graph is built and the per-batch no_grad()/ + # detach() dance is unnecessary. Cache hits return fresh tensors rebuilt + # from numpy (also grad-free). State is restored in the finally block so + # the caller's model is not permanently mutated. + original_requires_grad: Optional[List[Tuple[torch.nn.Parameter, bool]]] = None + if self.only_proj_layer: + original_requires_grad = [ + (p, p.requires_grad) for p in self.model.parameters() + ] + for p, _ in original_requires_grad: + p.requires_grad_(False) + + avg_loss = float("nan") + try: + # Training loop + for epoch in range(self.num_epochs): + total_loss = 0.0 + num_batches = 0 + + progress_bar = tqdm(self.train_loader, desc=f"Fine-tuning epoch {epoch+1}/{self.num_epochs}") + for batch in progress_bar: + with _autocast(): + if self.is_variant_effect_prediction: + # Variant effect task: (variant_seqs, ref_seqs, labels, conditional_input) + variant_sequences, ref_sequences, labels, conditional_input = batch + labels = labels.to(self.device) + + if self.only_proj_layer: + var_repr = fwd_cache.cached_call( + self.model._sequence_to_representative, + variant_sequences, + disable=self.disable_cache, + ) + ref_repr = fwd_cache.cached_call( + self.model._sequence_to_representative, + ref_sequences, + disable=self.disable_cache, + ) + else: + var_repr = self.model._sequence_to_representative(variant_sequences) + ref_repr = self.model._sequence_to_representative(ref_sequences) + + # Concatenate variant and reference representations + sequence_repr = torch.cat([var_repr, ref_repr], dim=1) + else: + # Single sequence task: (sequences, labels, conditional_input) + sequences, labels, conditional_input = batch + labels = labels.to(self.device) + + if self.only_proj_layer: + sequence_repr = fwd_cache.cached_call( + self.model._sequence_to_representative, + sequences, + disable=self.disable_cache, + ) + else: + sequence_repr = self.model._sequence_to_representative(sequences) + + logits = self.projection(sequence_repr) + loss = criterion(logits, labels) + + optimizer.zero_grad(set_to_none=True) + if needs_scaler: + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() else: - sequence_repr = self.model._sequence_to_representative(sequences) - - logits = self.projection(sequence_repr) - - loss = criterion(logits, labels) - - optimizer.zero_grad() - loss.backward() - optimizer.step() - - # Update loss tracking - total_loss += loss.item() - num_batches += 1 - avg_loss = total_loss / num_batches - - # Update progress bar with average loss - progress_bar.set_postfix({'avg_loss': f'{avg_loss:.4f}'}) - - if fwd_cache is not None: - fwd_cache.clear() + loss.backward() + optimizer.step() + + # Update loss tracking + total_loss += loss.item() + num_batches += 1 + avg_loss = total_loss / num_batches + + # Update progress bar with average loss + progress_bar.set_postfix({'avg_loss': f'{avg_loss:.4f}'}) + finally: + if original_requires_grad is not None: + for p, orig in original_requires_grad: + p.requires_grad_(orig) + + if fwd_cache is not None: + fwd_cache.clear() if self.num_epochs > 0: print(f"Fine-tuning completed. Final average loss: {avg_loss:.4f}")