1+ import typer
2+ import yaml
3+ from pathlib import Path
4+ import logging
5+
6+ import torch
7+ from transformers import DataCollatorForLanguageModeling , Trainer , TrainingArguments , IntervalStrategy
8+
9+ from utils .wandb_setup import setup_wandb
10+ from utils .metrics import compute_metrics
11+ from models import build_generator
12+ from data import GeneratorDataModule
13+
14+
# Configure root logging once for the whole script; module logger for local use.
LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
logger = logging.getLogger(__name__)

# CLI entry point container; commands are registered via @app.command().
app = typer.Typer()
19+
20+
@app.command()
def main(config_path: Path = typer.Argument(..., help="Path to YAML config")) -> None:
    """Fine-tune a causal language-model generator from a YAML config.

    Loads the config, initializes W&B reporting, builds the model/tokenizer,
    prepares datasets, trains with the Hugging Face ``Trainer``, saves the
    best checkpoint and metrics, and optionally evaluates on a test split.

    Args:
        config_path: Path to the YAML configuration file with ``model``,
            ``data``, and ``training`` sections.
    """
    # BUGFIX: original read `type.Argument(...)` (builtin `type` has no
    # `Argument`) — it must be `typer.Argument(...)`.
    cfg = yaml.safe_load(config_path.read_text())
    train_cfg = cfg['training']

    # --- SETUP W&B ---
    run_name, report_to = setup_wandb(cfg)

    # --- BUILD MODEL ---
    model, tokenizer = build_generator(cfg['model'])

    # --- SETUP DATA ---
    data_module = GeneratorDataModule(cfg['data'], tokenizer)
    data_module.setup()

    train_dataset = data_module.get_train_dataset()
    eval_dataset = data_module.get_eval_dataset()
    has_eval = eval_dataset is not None  # single consistent presence check

    # --- SETUP TRAINER ---
    # mlm=False: collator produces causal-LM (next-token) labels, not masked-LM.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )

    training_args_dict = {
        "output_dir": train_cfg['output_dir'],
        "overwrite_output_dir": train_cfg.get("overwrite_output_dir", True),
        "do_train": True,
        "do_eval": has_eval,
        "per_device_train_batch_size": train_cfg.get("per_device_train_batch_size", 8),
        "per_device_eval_batch_size": train_cfg.get("per_device_eval_batch_size", 16),
        "gradient_accumulation_steps": train_cfg.get("gradient_accumulation_steps", 1),
        "num_train_epochs": train_cfg.get("num_train_epochs", 3),
        "learning_rate": train_cfg.get("learning_rate", 5e-5),
        "warmup_ratio": train_cfg.get("warmup_ratio", 0.1),
        # Default to mixed precision whenever CUDA is available.
        "fp16": train_cfg.get("fp16", torch.cuda.is_available()),
        "logging_dir": train_cfg.get("logging_dir", f"{train_cfg['output_dir']}/logs"),
        "logging_steps": train_cfg.get("logging_steps", 100),
        # Periodic eval only makes sense when an eval split exists.
        "eval_strategy": IntervalStrategy.STEPS if has_eval else IntervalStrategy.NO,
        "eval_steps": train_cfg.get("eval_steps", 500),
        "save_strategy": IntervalStrategy.STEPS,
        "save_steps": train_cfg.get("save_steps", 500),
        "save_total_limit": train_cfg.get("save_total_limit", 2),
        "load_best_model_at_end": train_cfg.get("load_best_model_at_end", has_eval),
        # Consistency fix: use `has_eval` (is-not-None), not dataset truthiness,
        # so an empty-but-present eval set still selects on eval_loss.
        "metric_for_best_model": train_cfg.get("metric_for_best_model", "eval_loss" if has_eval else None),
        "greater_is_better": train_cfg.get("greater_is_better", False),
        "report_to": [report_to] if report_to != "none" else [],
        "run_name": run_name,
        "remove_unused_columns": False,
        "ddp_find_unused_parameters": train_cfg.get("ddp_find_unused_parameters", False),
    }

    training_args = TrainingArguments(**training_args_dict)
    logger.info(f"Training arguments: {training_args}. FP16 Enabled: {training_args.fp16}")

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics if has_eval else None,
    )

    # --- TRAIN ---
    logger.info("Training model...")
    train_result = trainer.train()
    logger.info(f"Training results: {train_result}")

    # Save final model & metrics
    logger.info(f"Saving best model to {training_args.output_dir}")
    trainer.save_model()  # Saves the best model due to load_best_model_at_end=True
    trainer.save_state()

    # Log final metrics
    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)

    # Evaluate on test set if available
    test_dataset = data_module.get_test_dataset()
    if test_dataset and train_cfg.get("do_test_eval", True):
        logger.info("Evaluating on test set...")
        test_metrics = trainer.evaluate(eval_dataset=test_dataset, metric_key_prefix="test")
        trainer.log_metrics("test", test_metrics)
        trainer.save_metrics("test", test_metrics)
        logger.info(f"Test set evaluation complete: {test_metrics}")

    logger.info("Script finished successfully.")
111+
112+
# Run the Typer CLI only when executed as a script (not on import).
if __name__ == "__main__":
    app()