
Commit fd75045

added gpt-2 fine tune training script
1 parent c79f971 commit fd75045

7 files changed: +240, −16 lines


notebooks/test.ipynb

Whitespace-only changes.

src/cli/01_train_teacher.py

Lines changed: 2 additions & 10 deletions
@@ -11,6 +11,7 @@

 from src.models import build_teacher
 from src.data import ClassificationDataModule
+from utils.wandb_setup import setup_wandb


 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
@@ -38,16 +39,7 @@ def main(config_path: Path = typer.Argument(..., help="Path to YAML config")):
     cfg = yaml.safe_load(config_path.read_text())

     # --- SETUP W&B ---
-    run_name = cfg['training'].get("run_name", f"teacher_train_{Path(cfg['training']['output_dir']).name}")
-    report_to = cfg['training'].get("report_to", "none")  # Default to no reporting
-    if report_to == "wandb":
-        project_name = cfg['training'].get("wandb_project", "senti_synth_teacher")
-        os.environ.pop("WANDB_DISABLED", None)  # Ensure it's enabled if requested
-        os.environ["WANDB_PROJECT"] = project_name
-        logger.info(f"Reporting to W&B project: {project_name}")
-    else:
-        os.environ["WANDB_DISABLED"] = "true"  # Explicitly disable
-        logger.info("W&B reporting disabled.")
+    run_name, report_to = setup_wandb(cfg)

     # --- BUILD MODEL ---
     model, tokenizer = build_teacher(cfg['model'])

src/cli/02_fine_tune_generator.py

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
+import typer
+import yaml
+from pathlib import Path
+import logging
+
+import torch
+from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments, IntervalStrategy
+
+from utils.wandb_setup import setup_wandb
+from src.models import build_generator
+from src.data import GeneratorDataModule
+
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
+
+app = typer.Typer()
+
+
+@app.command()
+def main(config_path: Path = typer.Argument(..., help="Path to YAML config")):
+    cfg = yaml.safe_load(config_path.read_text())
+
+    # --- SETUP W&B ---
+    run_name, report_to = setup_wandb(cfg)
+
+    # --- BUILD MODEL ---
+    model, tokenizer = build_generator(cfg['model'])
+
+    # --- SETUP DATA ---
+    data_module = GeneratorDataModule(cfg['data'], tokenizer)
+    data_module.setup()
+
+    train_dataset = data_module.get_train_dataset()
+    eval_dataset = data_module.get_eval_dataset()
+
+    # --- SETUP TRAINER ---
+    # Causal-LM objective: with mlm=False the collator derives labels from
+    # input_ids (pad positions are masked out of the loss).
+    data_collator = DataCollatorForLanguageModeling(
+        tokenizer=tokenizer,
+        mlm=False,
+    )
+
+    training_args_dict = {
+        "output_dir": cfg['training']['output_dir'],
+        "overwrite_output_dir": cfg['training'].get("overwrite_output_dir", True),
+        "do_train": True,
+        "do_eval": eval_dataset is not None,
+        "per_device_train_batch_size": cfg['training'].get("per_device_train_batch_size", 8),
+        "per_device_eval_batch_size": cfg['training'].get("per_device_eval_batch_size", 16),
+        "gradient_accumulation_steps": cfg['training'].get("gradient_accumulation_steps", 1),
+        "num_train_epochs": cfg['training'].get("num_train_epochs", 3),
+        "learning_rate": cfg['training'].get("learning_rate", 5e-5),
+        "warmup_ratio": cfg['training'].get("warmup_ratio", 0.1),
+        "fp16": cfg['training'].get("fp16", torch.cuda.is_available()),
+        "logging_dir": cfg['training'].get("logging_dir", f"{cfg['training']['output_dir']}/logs"),
+        "logging_steps": cfg['training'].get("logging_steps", 100),
+        "eval_strategy": IntervalStrategy.STEPS if eval_dataset is not None else IntervalStrategy.NO,
+        "eval_steps": cfg['training'].get("eval_steps", 500),
+        "save_strategy": IntervalStrategy.STEPS,
+        "save_steps": cfg['training'].get("save_steps", 500),
+        "save_total_limit": cfg['training'].get("save_total_limit", 2),
+        "load_best_model_at_end": cfg['training'].get("load_best_model_at_end", eval_dataset is not None),
+        "metric_for_best_model": cfg['training'].get("metric_for_best_model", "eval_loss" if eval_dataset else None),
+        "greater_is_better": cfg['training'].get("greater_is_better", False),
+        "report_to": [report_to] if report_to != "none" else [],
+        "run_name": run_name,
+        "remove_unused_columns": False,
+        "ddp_find_unused_parameters": cfg['training'].get("ddp_find_unused_parameters", False),
+    }
+
+    training_args = TrainingArguments(**training_args_dict)
+    logger.info(f"Training arguments: {training_args}. FP16 Enabled: {training_args.fp16}")
+
+    # NOTE: no compute_metrics here; utils.metrics.compute_metrics is
+    # classification-oriented (argmax over class logits) and does not apply to
+    # causal-LM outputs, so model selection relies on eval_loss instead.
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=train_dataset,
+        eval_dataset=eval_dataset,
+        tokenizer=tokenizer,
+        data_collator=data_collator,
+    )
+
+    # --- TRAIN ---
+    logger.info("Training model...")
+    train_result = trainer.train()
+    logger.info(f"Training results: {train_result}")
+
+    # Save final model & metrics
+    logger.info(f"Saving best model to {training_args.output_dir}")
+    trainer.save_model()  # Saves the best checkpoint when load_best_model_at_end is True
+    trainer.save_state()
+
+    # Log final metrics
+    metrics = train_result.metrics
+    trainer.log_metrics("train", metrics)
+    trainer.save_metrics("train", metrics)
+
+    # Evaluate on test set if available
+    test_dataset = data_module.get_test_dataset()
+    if test_dataset and cfg['training'].get("do_test_eval", True):
+        logger.info("Evaluating on test set...")
+        test_metrics = trainer.evaluate(eval_dataset=test_dataset, metric_key_prefix="test")
+        trainer.log_metrics("test", test_metrics)
+        trainer.save_metrics("test", test_metrics)
+        logger.info(f"Test set evaluation complete: {test_metrics}")
+
+    logger.info("Script finished successfully.")
+
+
+if __name__ == "__main__":
+    app()
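
For reference, the script's cfg lookups imply a YAML layout like the sketch below. The file path and all values are hypothetical; the keys (model.model_name, data.dataset_path, training.output_dir, training.report_to, and so on) are exactly the ones the script and setup_wandb read.

# configs/generator_finetune.yaml — hypothetical example
model:
  model_name: gpt2
  use_fast_tokenizer: true
data:
  dataset_path: data/processed/generator_corpus   # hypothetical path
  max_len: 32
training:
  output_dir: outputs/gpt2_generator
  run_name: gpt2_generator_ft
  report_to: wandb
  wandb_project: senti_synth_teacher
  num_train_epochs: 3
  learning_rate: 5.0e-5
  eval_steps: 500
  save_steps: 500

Since the script mixes src.* and utils.* imports, both the repo root and src/ need to be importable; one hypothetical invocation from the repo root would be: PYTHONPATH=.:./src python src/cli/02_fine_tune_generator.py configs/generator_finetune.yaml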

src/data.py

Lines changed: 74 additions & 1 deletion
@@ -1,6 +1,6 @@
 import logging
 from datasets import load_from_disk, DatasetDict
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, PreTrainedTokenizerBase
 logger = logging.getLogger(__name__)


@@ -76,3 +76,76 @@ def get_test_dataset(self):
         if not self.tokenized_datasets: self.setup()  # noqa: E701
         return self.tokenized_datasets["test"]

+
+
+class GeneratorDataModule:
+    """
+    Data module for generative fine-tuning.
+    Loads a DatasetDict from disk and tokenizes it for causal-LM training.
+    """
+    def __init__(self, cfg: dict, tokenizer: PreTrainedTokenizerBase):
+        self.cfg = cfg
+        self.tokenizer = tokenizer
+        self.dataset_path = cfg.get("dataset_path", None)
+
+        self.max_len = cfg.get("max_len", 32)
+
+        self.tokenized_datasets = None
+
+        self.required_splits = ["train", "val", "sanity", "test"]
+        self.text_column = "text"
+
+    def _load_clean_dataset(self) -> DatasetDict:
+        logger.info(f"Loading dataset from: {self.dataset_path}")
+        dataset = load_from_disk(self.dataset_path)
+
+        missing_splits = [s for s in self.required_splits if s not in dataset]
+        if missing_splits:
+            raise ValueError(f"Dataset missing splits: {missing_splits}")
+
+        return dataset
+
+    def _tokenize_function(self, examples):
+        """Tokenization function for map."""
+        return self.tokenizer(
+            examples[self.text_column],
+            truncation=True,
+            padding=False,
+            max_length=self.max_len
+        )
+
+    def setup(self):
+        """Loads and tokenizes the dataset."""
+        if self.tokenized_datasets:
+            return
+
+        raw_datasets = self._load_clean_dataset()
+
+        self.tokenized_datasets = raw_datasets.map(
+            self._tokenize_function,
+            batched=True,
+            remove_columns=[c for c in raw_datasets["train"].column_names
+                            if c not in ["input_ids", "attention_mask", "labels"]]
+        )
+
+        logger.info(f"Loaded and tokenized datasets with max length: {self.max_len}")
+        logger.info(f"Columns in tokenized datasets: {self.tokenized_datasets[self.required_splits[0]].column_names}")
+
+    def get_train_dataset(self):
+        if not self.tokenized_datasets: self.setup()  # noqa: E701
+        return self.tokenized_datasets["train"]
+
+    def get_eval_dataset(self):
+        if not self.tokenized_datasets: self.setup()  # noqa: E701
+        return self.tokenized_datasets["val"]
+
+    def get_sanity_dataset(self):
+        if not self.tokenized_datasets: self.setup()  # noqa: E701
+        return self.tokenized_datasets["sanity"]
+
+    def get_test_dataset(self):
+        if not self.tokenized_datasets: self.setup()  # noqa: E701
+        return self.tokenized_datasets["test"]
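
GeneratorDataModule expects a DatasetDict saved to disk with exactly the four splits it validates. A minimal sketch of producing a compatible dataset, with toy rows and a hypothetical path:

from datasets import Dataset, DatasetDict

# Toy rows; the real corpus would come from the project's data pipeline.
rows = {
    "train": ["great movie", "terrible plot twist"],
    "val": ["fine acting"],
    "sanity": ["ok"],
    "test": ["loved it"],
}

# One entry per split in GeneratorDataModule.required_splits, keyed on the
# "text" column the tokenizer reads.
dataset = DatasetDict({split: Dataset.from_dict({"text": texts})
                       for split, texts in rows.items()})
dataset.save_to_disk("data/processed/generator_corpus")  # -> cfg["dataset_path"]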

src/models.py

Lines changed: 22 additions & 0 deletions
@@ -3,6 +3,7 @@
 """
 import logging
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
+from transformers import GPT2LMHeadModel

 logger = logging.getLogger(__name__)

@@ -31,4 +32,25 @@ def build_teacher(cfg: dict):
     logger.info(f"Loading tokenizer for: {model_name}")
     tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=use_fast_tokenizer)

+    return model, tokenizer
+
+
+def build_generator(cfg: dict):
+    """
+    Builds and returns the generator model and tokenizer using Hugging Face.
+
+    Args:
+        cfg (dict): Configuration dictionary for the model, expecting keys
+            like "model_name" and "use_fast_tokenizer".
+
+    Returns:
+        tuple: (model, tokenizer)
+    """
+    model_name = cfg.get("model_name", "gpt2")
+    use_fast_tokenizer = cfg.get("use_fast_tokenizer", True)
+
+    logger.info(f"Loading generator model: {model_name}")
+    model = GPT2LMHeadModel.from_pretrained(model_name)
+    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=use_fast_tokenizer)
+    # GPT-2 has no pad token by default; reuse EOS so batched padding works.
+    tokenizer.pad_token = tokenizer.eos_token
+
     return model, tokenizer
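
A minimal usage sketch of build_generator (config values hypothetical). One caveat worth knowing: because padding reuses EOS, DataCollatorForLanguageModeling masks pad positions out of the loss by pad_token_id, which also masks genuine EOS tokens; that is a common, accepted trade-off of this GPT-2 workaround.

from src.models import build_generator

model, tokenizer = build_generator({"model_name": "gpt2", "use_fast_tokenizer": True})
assert tokenizer.pad_token == tokenizer.eos_token  # padding falls back to EOS
print(model.config.n_positions)  # 1024-token context window for GPT-2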

src/utils/metrics.py

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+import numpy as np
+from sklearn.metrics import precision_recall_fscore_support, accuracy_score
+
+
+def compute_metrics(p):
+    """Computes metrics for HF Trainer."""
+    preds = np.argmax(p.predictions, axis=1)
+    labels = p.label_ids
+    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')  # Assuming binary
+    acc = accuracy_score(labels, preds)
+    return {
+        'accuracy': acc,
+        'f1': f1,
+        'precision': precision,
+        'recall': recall
+    }
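
These argmax-based metrics match the classification teacher's output head. For the causal-LM generator trained by 02_fine_tune_generator.py, logits have shape (batch, sequence, vocab), so they do not apply; the customary signal there is the Trainer's eval_loss, or perplexity derived from it. A small sketch (the helper name is hypothetical):

import math

def perplexity_from_eval(metrics: dict, key: str = "eval_loss") -> float:
    """Perplexity = exp(mean token cross-entropy reported by Trainer.evaluate())."""
    return math.exp(metrics[key])

# e.g. after trainer.evaluate(eval_dataset=test_dataset, metric_key_prefix="test"):
# ppl = perplexity_from_eval(test_metrics, key="test_loss")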

src/utils/wandb_setup.py

Lines changed: 12 additions & 5 deletions
@@ -1,12 +1,19 @@
 import os
 import logging
-
+from pathlib import Path
 logger = logging.getLogger(__name__)


 def setup_wandb(cfg: dict):
     """Setup WANDB for logging."""
-    project_name = cfg.get("project_name", "senti-synth-teacher")
-    os.environ.pop("WANDB_DISABLED", None)
-    os.environ["WANDB_PROJECT"] = project_name
-    logger.info(f"Reporting to W&B project: {project_name}")
+    run_name = cfg['training'].get("run_name", f"teacher_train_{Path(cfg['training']['output_dir']).name}")
+    report_to = cfg['training'].get("report_to", "none")  # Default to no reporting
+    if report_to == "wandb":
+        project_name = cfg['training'].get("wandb_project", "senti_synth_teacher")
+        os.environ.pop("WANDB_DISABLED", None)  # Ensure it's enabled if requested
+        os.environ["WANDB_PROJECT"] = project_name
+        logger.info(f"Reporting to W&B project: {project_name}")
+    else:
+        os.environ["WANDB_DISABLED"] = "true"  # Explicitly disable
+        logger.info("W&B reporting disabled.")
+    return run_name, report_to
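
A usage sketch of setup_wandb with a minimal config (values hypothetical). Note the default run_name is prefixed teacher_train_ even when the generator script calls it, so generator configs will usually want an explicit run_name:

from utils.wandb_setup import setup_wandb

cfg = {"training": {"output_dir": "outputs/teacher_run",
                    "report_to": "wandb",
                    "wandb_project": "senti_synth_teacher"}}
run_name, report_to = setup_wandb(cfg)
print(run_name)   # "teacher_train_teacher_run", derived from output_dir's basename
print(report_to)  # "wandb"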
