From 99c698edae06f4d763ca9825dc0ec04b86ae5a34 Mon Sep 17 00:00:00 2001 From: Dan Date: Sun, 15 Feb 2026 11:42:44 +0100 Subject: [PATCH] Add HF-only LoRA training path without llm-foundry deps --- README.md | 9 +- configs/finetuning/lora.yaml | 24 ++ pyproject.toml | 8 + scripts/README.md | 5 + scripts/prepare_train_eval.sh | 9 +- scripts/train_lora.sh | 28 +++ src/panza/finetuning/train.py | 136 +++++++--- src/panza/finetuning/train_lora_hf.py | 341 ++++++++++++++++++++++++++ 8 files changed, 530 insertions(+), 30 deletions(-) create mode 100644 configs/finetuning/lora.yaml create mode 100755 scripts/train_lora.sh create mode 100644 src/panza/finetuning/train_lora_hf.py diff --git a/README.md b/README.md index 2150033..e1833a5 100755 --- a/README.md +++ b/README.md @@ -87,6 +87,10 @@ If you want to also finetune models using Panza, you will need to install additi ``` bash pip install .[training] ``` +For standard LoRA-only fine-tuning without RoSA/spops dependencies, install: +``` bash +pip install .[training_lora] +``` ## :rocket: Getting started @@ -177,7 +181,8 @@ Run `CUDA_VISIBLE_DEVICES=X ./prepare_data.sh`.
We currently support `LLaMA3-8B-Instruct` and `Mistral-Instruct-v0.2` LLMs as base models; the former is the default, but we obtained good results with either model. -1. [Recommended] For parameter efficient fine-tuning, run `./train_rosa.sh`. +1. [Recommended] For parameter efficient fine-tuning, run `./train_rosa.sh`. +If you want standard LoRA only (no RoSA sparse masks, no spops, no llm-foundry/composer), run `./train_lora.sh`. If a larger GPU is available and full-parameter fine-tuning is possible, run `./train_fft.sh`. 2. We have prepopulated the training configs with parameter values that worked best for us. We recommend you try those first, but you can also experiment with different hyper-parameters by passing extra arguments to the training script, such as `lr`, `lora_lr`, `num_epochs`. All the trained models are saved in the `checkpoints` directory. @@ -187,6 +192,8 @@ Examples: CUDA_VISIBLE_DEVICES=X ./train_rosa.sh # Will use the default parameters. CUDA_VISIBLE_DEVICES=X ./train_rosa.sh finetuning.lr=1e-6 finetuning.rosa_lr=1e-6 finetuning.max_duration=7ep + +CUDA_VISIBLE_DEVICES=X ./train_lora.sh finetuning.lr=1e-6 finetuning.lora.lora_lr=1e-6 finetuning.max_duration=7ep ``` On a smaller GPU, it may be necessary to further train in lower precision (QRoSA). This can be run as follows: diff --git a/configs/finetuning/lora.yaml b/configs/finetuning/lora.yaml new file mode 100644 index 0000000..3fb212c --- /dev/null +++ b/configs/finetuning/lora.yaml @@ -0,0 +1,24 @@ +defaults: + - base + +max_duration: 5ep +lr: 1e-5 +batch_size: 8 +eval_interval: 1 +seed: ${seed} +model_name_or_path: "ISTA-DASLab/Meta-Llama-3-8B-Instruct" +save_merged_model: False + +lora: + lora_lr: ${finetuning.lr} + r: 8 + lora_alpha: 16 + target_modules: all-linear + lora_dropout: 0.05 + bias: none + task_type: CAUSAL_LM + +scheduler: + t_warmup: 8ba + +num_cpu_threads: 1 diff --git a/pyproject.toml b/pyproject.toml index f1675f5..de81793 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,14 @@ training = [ "peft@git+https://github.com/IST-DASLab/peft-rosa.git@grad_quant_looser_versioning", "spops-sm-80", ] +training_lora = [ + "langdetect", + "fire", + "gradio", + "cmake", + "packaging", + "peft", +] contributing = [ "pre-commit", ] diff --git a/scripts/README.md b/scripts/README.md index 400efc6..5772b1f 100755 --- a/scripts/README.md +++ b/scripts/README.md @@ -17,6 +17,7 @@ This directory contains all scripts necessary to train and run Panza. We provide #### Training * `train_rosa.sh` performs [parameter-efficient training](https://arxiv.org/pdf/2401.04679.pdf). +* `train_lora.sh` performs standard LoRA parameter-efficient training directly through Hugging Face Transformers (no RoSA masks/sparsity and no llm-foundry/composer dependency). * `train_fft.sh` performs full-parameter/full-rank training. 
_Note that this requires additional computational resources (about 2x)._ @@ -70,6 +71,10 @@ and ```bash pip install panza_mail[training] +``` +For standard LoRA-only fine-tuning without RoSA/spops dependencies, install: +```bash +pip install panza_mail[training_lora] ``` #### Inference diff --git a/scripts/prepare_train_eval.sh b/scripts/prepare_train_eval.sh index a7e8574..6161197 100755 --- a/scripts/prepare_train_eval.sh +++ b/scripts/prepare_train_eval.sh @@ -53,4 +53,11 @@ elif [[ $training_mode == "full" ]]; then echo "Generating json evaluation" python runner.py interfaces=json writer/llm=transformers fi -fi \ No newline at end of file +elif [[ $training_mode == "lora" ]]; then + python ../src/panza/finetuning/train_lora_hf.py \ + finetuning=lora ${vars[@]} + if [[ $test_split != "0" ]]; then + echo "Generating json evaluation" + python runner.py interfaces=json writer/llm=peft + fi +fi diff --git a/scripts/train_lora.sh b/scripts/train_lora.sh new file mode 100755 index 0000000..4eb70ce --- /dev/null +++ b/scripts/train_lora.sh @@ -0,0 +1,28 @@ +# Convenience script for running standard LoRA finetuning. +# All arguments to the python script can be provided +# here exactly in the form they would be passed to the +# python script directly. +# +# Example usage: +# ./train_lora.sh user=alonso trainer.optimizer.lr=0.1 + +set -e + +vars=() +idx=1 + +# process input arguments +for argument in "$@" +do + key=$(echo $argument | cut -f1 -d=) + + if [[ $key == finetuning ]]; then + echo "The 'finetuning' argument is already set and should not be overridden here; override is ignored." + else + vars[idx]=$argument + idx+=1 + fi +done + +python ../src/panza/finetuning/train_lora_hf.py \ + finetuning=lora ${vars[@]} diff --git a/src/panza/finetuning/train.py b/src/panza/finetuning/train.py index 8f742a7..4309a74 100755 --- a/src/panza/finetuning/train.py +++ b/src/panza/finetuning/train.py @@ -12,7 +12,6 @@ import warnings from pathlib import Path from typing import Any, Dict, List, Optional, Union -import spops import torch from composer import Trainer @@ -38,8 +37,7 @@ from llmfoundry.utils import find_mosaicml_logger, log_train_analytics, maybe_create_mosaicml_logger from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om -from peft import get_peft_model -from peft.tuners.rosa import RosaConfig, RosaModel, RosaScheduler +from peft import LoraConfig, get_peft_model from rich.traceback import install from transformers import AutoModelForCausalLM, BitsAndBytesConfig, PreTrainedTokenizerBase @@ -72,6 +70,30 @@ log = logging.getLogger(__name__) +try: + import spops +except ImportError: + spops = None + +try: + from peft.tuners.rosa import RosaConfig, RosaModel, RosaScheduler +except ImportError: + RosaConfig = None + RosaModel = None + RosaScheduler = None + + +def get_adapter_type(finetuning_cfg: DictConfig) -> Optional[str]: + has_rosa = "rosa" in finetuning_cfg + has_lora = "lora" in finetuning_cfg + if has_rosa and has_lora: + raise ValueError("Only one adapter mode can be enabled at a time (choose either rosa or lora).") + if has_rosa: + return "rosa" + if has_lora: + return "lora" + return None + def validate_config(cfg: DictConfig): """Validates compatible model and dataloader selection.""" @@ -156,10 +178,11 @@ def create_run_name(cfg: DictConfig) -> str: run_name += f"-{cfg.model_precision}" run_name += f"-bs{cfg.finetuning.batch_size}" - if hasattr(cfg.finetuning, "rosa"): - run_name += "-rosa" - else: + adapter_type = get_adapter_type(cfg.finetuning) + if 
adapter_type is None: run_name += "-fft" + else: + run_name += f"-{adapter_type}" run_name += f"-lr{cfg.finetuning.lr}" run_name += f"-{cfg.finetuning.max_duration}" @@ -234,12 +257,15 @@ def override_config(cfg: DictConfig) -> None: if not cfg.finetuning.run_name: cfg.finetuning.run_name = create_run_name(cfg) - if hasattr(cfg.finetuning, "rosa"): + adapter_type = get_adapter_type(cfg.finetuning) + if adapter_type == "rosa": cfg.finetuning.rosa.rosa_dtype = get_rosa_dtype(cfg) if cfg.finetuning.rosa.spa_d != 0: override_rosa_schedule(cfg, mask_generation=cfg.finetuning.rosa.masks_only) - else: - cfg.finetuning.callbacks.hf_checkpointer.precision = get_hf_save_precision(cfg) + elif adapter_type is None: + callbacks_cfg = cfg.finetuning.get("callbacks", None) + if callbacks_cfg is not None and "hf_checkpointer" in callbacks_cfg: + callbacks_cfg.hf_checkpointer.precision = get_hf_save_precision(cfg) # Re-enable struct mode to lock down the configuration OmegaConf.set_struct(cfg, True) @@ -253,9 +279,10 @@ def save_config_to_yaml(cfg: DictConfig) -> str: def build_composer_peft_model( - model_config: str, - rosa_config: Dict[str, Any], + model_config: DictConfig, tokenizer: PreTrainedTokenizerBase, + rosa_config: Optional[Dict[str, Any]] = None, + lora_config: Optional[Dict[str, Any]] = None, is_fsdp: bool = False, ) -> ComposerHFCausalLM: @@ -293,6 +320,10 @@ def build_composer_peft_model( print("Model built!") if rosa_config is not None: + if RosaConfig is None: + raise ImportError( + "RoSA fine-tuning requires peft-rosa. Install training dependencies with RoSA support." + ) print("Building RoSA config...") config = RosaConfig( r=rosa_config["lora_r"], @@ -319,6 +350,19 @@ def build_composer_peft_model( print("Adding RoSA modules...") model = get_peft_model(model, config) print("RoSA modules added!") + elif lora_config is not None: + print("Building LoRA config...") + config = LoraConfig( + r=lora_config.get("r", 8), + lora_alpha=lora_config.get("lora_alpha", 16), + target_modules=lora_config.get("target_modules", "all-linear"), + lora_dropout=lora_config.get("lora_dropout", 0.05), + bias=lora_config.get("bias", "none"), + task_type=lora_config.get("task_type", "CAUSAL_LM"), + ) + print("Adding LoRA modules...") + model = get_peft_model(model, config) + print("LoRA modules added!") train_metrics = [LanguageCrossEntropy(), LanguagePerplexity()] eval_metrics = [ @@ -445,6 +489,11 @@ def main(cfg: DictConfig) -> Trainer: rosa_config: Optional[Dict[str, Any]] = pop_config( cfg, "rosa", must_exist=False, default_value=None, convert=True ) + lora_config: Optional[Dict[str, Any]] = pop_config( + cfg, "lora", must_exist=False, default_value=None, convert=True + ) + if rosa_config is not None and lora_config is not None: + raise ValueError("Both rosa and lora configs were provided. Select only one adapter mode.") hf_save_path: Union[int, str] = pop_config(cfg, "hf_save_path", must_exist=True) @@ -554,7 +603,13 @@ def main(cfg: DictConfig) -> Trainer: if num_cpu_threads > 0: print(f"Setting number of CPU threads to {num_cpu_threads}") torch.set_num_threads(num_cpu_threads) - spops.set_num_threads(num_cpu_threads) + if rosa_config is not None: + if spops is None: + warnings.warn( + "spops is not installed; skipping spops.set_num_threads for RoSA mode." 
+ ) + else: + spops.set_num_threads(num_cpu_threads) # Enable autoresume from model checkpoints if possible autoresume_default: bool = False @@ -679,6 +734,7 @@ def main(cfg: DictConfig) -> Trainer: use_async_eval = any(isinstance(c, AsyncEval) for c in callbacks) print("ROSA CONFIG", rosa_config) + print("LORA CONFIG", lora_config) # Build Model print("Initializing model...") with init_context: @@ -686,9 +742,13 @@ def main(cfg: DictConfig) -> Trainer: fsdp_config is None or rosa_config is None ), "fsdp is cuurently not supported with RoSA" model = build_composer_peft_model( - model_config, rosa_config, tokenizer, is_fsdp=fsdp_config is not None + model_config, + tokenizer, + rosa_config=rosa_config, + lora_config=lora_config, + is_fsdp=fsdp_config is not None, ) - if rosa_config is not None: + if rosa_config is not None and RosaModel is not None: assert isinstance(model.model.base_model, RosaModel) # Algorithms @@ -702,6 +762,10 @@ def main(cfg: DictConfig) -> Trainer: ) if rosa_config is not None: + if RosaScheduler is None: + raise ImportError( + "RoSA fine-tuning requires peft-rosa. Install training dependencies with RoSA support." + ) algorithms.append(RosaScheduler(model.model.base_model)) # Dataloaders @@ -776,24 +840,40 @@ def main(cfg: DictConfig) -> Trainer: # Optimizer optimizer_name: str = optimizer_config.pop("name") - if rosa_config is None or "lora_lr" not in rosa_config: + adapter_config = rosa_config if rosa_config is not None else lora_config + adapter_lr = adapter_config.get("lora_lr") if adapter_config is not None else None + if adapter_lr is None: optimizer = build_optimizer(model, optimizer_name, optimizer_config) else: - print(f'Using a different learning rate for lora params {rosa_config["lora_lr"]}') + print(f"Using a different learning rate for LoRA params {adapter_lr}") assert optimizer_name == "decoupled_adamw" lora_params = [] other_params = [] + adapter_param_keys = ( + ["rosa_A", "rosa_B", "rosa_embedding_A", "rosa_embedding_B"] + if rosa_config is not None + else ["lora_A", "lora_B", "lora_embedding_A", "lora_embedding_B"] + ) for name, param in model.named_parameters(): - if any( - [k in name for k in ["rosa_A", "rosa_B", "rosa_embedding_A", "rosa_embedding_B"]] - ): + if not param.requires_grad: + continue + if any([k in name for k in adapter_param_keys]): lora_params.append(param) else: other_params.append(param) - print(f"Found {len(lora_params)} lora params and {len(other_params)} other params") - params = [{"params": other_params}, {"params": lora_params, "lr": rosa_config["lora_lr"]}] - optimizer = DecoupledAdamW(params, **optimizer_config) + if len(lora_params) == 0: + warnings.warn( + "No LoRA parameters were detected for split learning rates; " + "falling back to the base optimizer config." + ) + optimizer = build_optimizer(model, optimizer_name, optimizer_config) + else: + print(f"Found {len(lora_params)} LoRA params and {len(other_params)} other params") + params = [{"params": lora_params, "lr": adapter_lr}] + if len(other_params) > 0: + params.insert(0, {"params": other_params}) + optimizer = DecoupledAdamW(params, **optimizer_config) # Now add the eval metrics try: @@ -878,19 +958,19 @@ def main(cfg: DictConfig) -> Trainer: # subdirectory that the HF writer wrote it into, and into # our desired and expected location. Only needed for full # (not low-rank) finetuning. 
- if rosa_config is None and torch.distributed.get_rank() == 0: + if rosa_config is None and lora_config is None and torch.distributed.get_rank() == 0: path_to_save = os.path.join(hf_save_path, run_name) hf_output_path = os.path.join(path_to_save, "huggingface") for filename in glob.glob(os.path.join(hf_output_path, "*", "*")): shutil.copy(filename, path_to_save) shutil.rmtree(os.path.join(hf_output_path)) - # if rosa is enabled, save the model manually, since - # llm-foundry's checkpointing doesn't work properly with RoSA - if rosa_config is not None: - assert fsdp_config is None, "fsdp is currently not supported with RoSA" + # If PEFT is enabled, save adapters manually. + if rosa_config is not None or lora_config is not None: + if rosa_config is not None: + assert fsdp_config is None, "fsdp is currently not supported with RoSA" path_to_save = os.path.join(hf_save_path, run_name) - print(f"saving the model to {path_to_save}") + print(f"Saving the model to {path_to_save}") if torch.distributed.get_rank() == 0: model.model.save_pretrained( path_to_save, is_main_process=True, state_dict=model.model.state_dict() diff --git a/src/panza/finetuning/train_lora_hf.py b/src/panza/finetuning/train_lora_hf.py new file mode 100644 index 0000000..17a4107 --- /dev/null +++ b/src/panza/finetuning/train_lora_hf.py @@ -0,0 +1,341 @@ +import math +import os +import tempfile +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple + +import hydra +import torch +from datasets import load_dataset +from omegaconf import DictConfig, OmegaConf +from peft import LoraConfig, get_peft_model +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + BitsAndBytesConfig, + Trainer, + TrainingArguments, + get_scheduler, + set_seed, +) + +from panza import PanzaWriter # The import also loads custom Hydra resolvers. + + +def create_lora_run_name(cfg: DictConfig) -> str: + run_name = f"panza_{cfg.user.username}" + model_name = cfg.finetuning.model_name_or_path.split("/")[-1] + run_name += f"-{model_name}" + run_name += f"-{cfg.model_precision}" + run_name += f"-bs{cfg.finetuning.batch_size}" + run_name += "-lora" + run_name += f"-lr{cfg.finetuning.lr}" + run_name += f"-{cfg.finetuning.max_duration}" + run_name += f"-seed{cfg.finetuning.seed}" + return run_name + + +def parse_num_epochs(max_duration: Any) -> float: + if isinstance(max_duration, (int, float)): + return float(max_duration) + if isinstance(max_duration, str) and max_duration.endswith("ep"): + return float(max_duration[:-2]) + raise ValueError( + f"Unsupported finetuning.max_duration value: {max_duration}. " + "For LoRA HF training, use values like '5ep'." 
+ ) + + +def parse_warmup_steps(t_warmup: Any) -> int: + if isinstance(t_warmup, int): + return t_warmup + if isinstance(t_warmup, str) and t_warmup.endswith("ba"): + return int(t_warmup[:-2]) + return 0 + + +def save_config_to_yaml(cfg: DictConfig) -> str: + cfg_dict = OmegaConf.to_container(cfg, resolve=True) + with tempfile.NamedTemporaryFile("w", delete=False, suffix=".yaml") as temp_file: + OmegaConf.save(config=cfg_dict, f=temp_file.name) + return temp_file.name + + +def get_model_dtype(cfg: DictConfig) -> torch.dtype: + if cfg.model_precision == "bf16": + return torch.bfloat16 + if cfg.model_precision == "fp32": + return torch.float32 + if cfg.model_precision == "4bit": + return torch.bfloat16 + raise ValueError(f"Unsupported model_precision: {cfg.model_precision}") + + +def get_quantization_config(model_cfg: DictConfig) -> Optional[BitsAndBytesConfig]: + weight_bias_dtype = model_cfg.get("weight_bias_dtype", None) + if weight_bias_dtype == "4bit": + return BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + ) + return None + + +def tokenize_example(example: Dict[str, Any], tokenizer: Any, max_seq_len: int) -> Dict[str, Any]: + prompt_ids = tokenizer(example["prompt"], add_special_tokens=False)["input_ids"] + response_ids = tokenizer(example["response"], add_special_tokens=False)["input_ids"] + + input_ids = prompt_ids + response_ids + labels = ([-100] * len(prompt_ids)) + response_ids + + if len(input_ids) > max_seq_len: + input_ids = input_ids[-max_seq_len:] + labels = labels[-max_seq_len:] + + attention_mask = [1] * len(input_ids) + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": labels, + } + + +@dataclass +class CausalLMDataCollator: + pad_token_id: int + label_pad_token_id: int = -100 + + def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]: + max_len = max(len(feature["input_ids"]) for feature in features) + + input_ids: List[List[int]] = [] + attention_masks: List[List[int]] = [] + labels: List[List[int]] = [] + for feature in features: + pad_len = max_len - len(feature["input_ids"]) + input_ids.append(feature["input_ids"] + [self.pad_token_id] * pad_len) + attention_masks.append(feature["attention_mask"] + [0] * pad_len) + labels.append(feature["labels"] + [self.label_pad_token_id] * pad_len) + + return { + "input_ids": torch.tensor(input_ids, dtype=torch.long), + "attention_mask": torch.tensor(attention_masks, dtype=torch.long), + "labels": torch.tensor(labels, dtype=torch.long), + } + + +def build_optimizer( + model: torch.nn.Module, + optimizer_cfg: DictConfig, + base_lr: float, + lora_lr: Optional[float], +) -> torch.optim.Optimizer: + optimizer_name = optimizer_cfg.get("name", "decoupled_adamw") + if optimizer_name != "decoupled_adamw": + raise ValueError( + f"Unsupported optimizer '{optimizer_name}' for HF LoRA training. " + "Use decoupled_adamw." 
+ ) + + betas_cfg = optimizer_cfg.get("betas", [0.9, 0.999]) + betas: Tuple[float, float] = (float(betas_cfg[0]), float(betas_cfg[1])) + eps = float(optimizer_cfg.get("eps", 1e-8)) + weight_decay = float(optimizer_cfg.get("weight_decay", 0.0)) + + if lora_lr is None: + trainable_params = [param for param in model.parameters() if param.requires_grad] + return torch.optim.AdamW( + trainable_params, + lr=base_lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + ) + + lora_params = [] + other_params = [] + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + if any(key in name for key in ["lora_A", "lora_B", "lora_embedding_A", "lora_embedding_B"]): + lora_params.append(param) + else: + other_params.append(param) + + if not lora_params: + trainable_params = [param for param in model.parameters() if param.requires_grad] + return torch.optim.AdamW( + trainable_params, + lr=base_lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + ) + + param_groups: List[Dict[str, Any]] = [{"params": lora_params, "lr": float(lora_lr)}] + if other_params: + param_groups.insert(0, {"params": other_params, "lr": base_lr}) + + return torch.optim.AdamW( + param_groups, + betas=betas, + eps=eps, + weight_decay=weight_decay, + ) + + +@hydra.main(version_base="1.1", config_path="../../../configs", config_name="panza_finetuning") +def main(cfg: DictConfig) -> None: + OmegaConf.set_struct(cfg, False) + if "lora" not in cfg.finetuning: + raise ValueError("This trainer only supports finetuning=lora.") + if "rosa" in cfg.finetuning: + raise ValueError("This trainer does not support RoSA. Use scripts/train_rosa.sh instead.") + + if not cfg.finetuning.run_name: + cfg.finetuning.run_name = create_lora_run_name(cfg) + OmegaConf.resolve(cfg) + + cfg.preprocessing.model = cfg.finetuning.model_name_or_path + if "retriever" in cfg.preprocessing.prompting: + # LoRA training does not require RAG retrieval and should not require FAISS assets. 
+ cfg.preprocessing.prompting.retriever = OmegaConf.create( + {"_target_": "panza.retriever.NoneRetriever"} + ) + preprocessing_yaml = save_config_to_yaml(cfg.preprocessing) + + os.environ["PANZA_PREPROCESSING_CONFIG"] = preprocessing_yaml + os.environ["WANDB_PROJECT"] = f"panza-{cfg.user.username}" + os.environ["WANDB_DISABLED"] = str(int(cfg.finetuning.wandb_disabled)) + + set_seed(int(cfg.finetuning.seed)) + + from panza.finetuning.preprocessing import panza_preprocessing_function + + train_file = os.path.join(cfg.user.data_dir, "train.jsonl") + if not os.path.exists(train_file): + raise FileNotFoundError(f"Training data not found at {train_file}") + + tokenizer = AutoTokenizer.from_pretrained( + cfg.finetuning.model_name_or_path, + model_max_length=cfg.finetuning.max_seq_len, + trust_remote_code=True, + ) + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token + + model_dtype = get_model_dtype(cfg) + quantization_config = get_quantization_config(cfg.finetuning.model) + + model = AutoModelForCausalLM.from_pretrained( + cfg.finetuning.model_name_or_path, + torch_dtype=model_dtype, + quantization_config=quantization_config, + trust_remote_code=True, + use_cache=False, + device_map="auto" if quantization_config is not None else None, + attn_implementation="eager", + ) + + lora_cfg = cfg.finetuning.lora + peft_config = LoraConfig( + r=int(lora_cfg.get("r", 8)), + lora_alpha=int(lora_cfg.get("lora_alpha", 16)), + target_modules=lora_cfg.get("target_modules", "all-linear"), + lora_dropout=float(lora_cfg.get("lora_dropout", 0.05)), + bias=lora_cfg.get("bias", "none"), + task_type=lora_cfg.get("task_type", "CAUSAL_LM"), + ) + model = get_peft_model(model, peft_config) + model.print_trainable_parameters() + + train_dataset = load_dataset("json", data_files=train_file, split="train") + train_dataset = train_dataset.map( + panza_preprocessing_function, + remove_columns=train_dataset.column_names, + num_proc=1, + ) + train_dataset = train_dataset.map( + lambda example: tokenize_example( + example, + tokenizer=tokenizer, + max_seq_len=int(cfg.finetuning.max_seq_len), + ), + remove_columns=train_dataset.column_names, + num_proc=1, + ) + + per_device_train_batch_size = int(cfg.finetuning.get("device_train_microbatch_size", 1)) + target_batch_size = int(cfg.finetuning.batch_size) + gradient_accumulation_steps = max(1, math.ceil(target_batch_size / per_device_train_batch_size)) + + num_train_epochs = parse_num_epochs(cfg.finetuning.max_duration) + warmup_steps = parse_warmup_steps(cfg.finetuning.scheduler.get("t_warmup", 0)) + + output_dir = os.path.join(cfg.finetuning.hf_save_path, cfg.finetuning.run_name) + os.makedirs(output_dir, exist_ok=True) + + training_args = TrainingArguments( + output_dir=output_dir, + per_device_train_batch_size=per_device_train_batch_size, + gradient_accumulation_steps=gradient_accumulation_steps, + learning_rate=float(cfg.finetuning.lr), + num_train_epochs=num_train_epochs, + bf16=(cfg.finetuning.precision == "amp_bf16"), + fp16=False, + logging_strategy="steps", + logging_steps=1, + save_strategy="epoch", + report_to=[] if cfg.finetuning.wandb_disabled else ["wandb"], + remove_unused_columns=False, + ) + + optimizer = build_optimizer( + model=model, + optimizer_cfg=cfg.finetuning.optimizer, + base_lr=float(cfg.finetuning.lr), + lora_lr=( + float(cfg.finetuning.lora.lora_lr) + if "lora_lr" in cfg.finetuning.lora + else None + ), + ) + + num_update_steps_per_epoch = max( + 1, + math.ceil(len(train_dataset) / (per_device_train_batch_size * 
gradient_accumulation_steps)), + ) + num_training_steps = max(1, math.ceil(num_train_epochs * num_update_steps_per_epoch)) + scheduler = get_scheduler( + name="linear", + optimizer=optimizer, + num_warmup_steps=warmup_steps, + num_training_steps=num_training_steps, + ) + + trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_dataset, + data_collator=CausalLMDataCollator(pad_token_id=tokenizer.pad_token_id), + optimizers=(optimizer, scheduler), + ) + + trainer.train() + + model.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + + if bool(cfg.finetuning.get("save_merged_model", False)): + merged_dir = os.path.join(output_dir, "merged") + merged_model = model.merge_and_unload() + merged_model.save_pretrained(merged_dir) + tokenizer.save_pretrained(merged_dir) + + os.remove(preprocessing_yaml) + + +if __name__ == "__main__": + main()
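
Reviewer note (not part of the patch): below is a minimal sketch of how the adapter saved by `train_lora_hf.py` could be loaded for a quick generation smoke test. It assumes the checkpoint layout produced above (`hf_save_path/run_name` containing the PEFT adapter plus tokenizer); the checkpoint path and prompt are placeholders, and real Panza inference should still go through `runner.py` with `writer/llm=peft`, as wired up in `prepare_train_eval.sh`.

```python
# Hedged sketch: load the LoRA adapter written by train_lora_hf.py and run one
# greedy generation. The checkpoint path is a placeholder for your own
# <hf_save_path>/<run_name> directory.
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

checkpoint_dir = "checkpoints/<your-lora-run-name>"  # hypothetical path

# AutoPeftModelForCausalLM reads adapter_config.json, loads the base model it
# references, and attaches the saved LoRA weights on top.
model = AutoPeftModelForCausalLM.from_pretrained(
    checkpoint_dir,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)

prompt = "Write an email letting the team know the standup moved to 3pm."
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=128, do_sample=False)

# Strip the prompt tokens and print only the generated continuation.
print(tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```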