
Commit db8bd9a

add create dataset and create student merged datasets
1 parent d7f79d5 commit db8bd9a

File tree

11 files changed: +438 -37 lines changed


configs/generator/sst2_hf.yaml

Lines changed: 5 additions & 4 deletions
@@ -2,6 +2,7 @@ model:
   model_name: "gpt2"              # you can swap in "gpt2-medium" or "EleutherAI/pythia-70m" etc.
   use_fast_tokenizer: true
   block_size: 128                 # Maximum sequence length after tokenisation
+  unfrozen_layers: 4              # how many *top* transformer blocks stay trainable
 
 data:
   dataset_path: "./data/clean/"   # Use HF dataset identifier
@@ -26,16 +27,16 @@ training:
   num_train_epochs: 3             # SST-2 is tiny; 2-3 epochs suffice
 
   # ── precision & speed ──────────────────────────────────────────────────────
-  fp16: true                      # enable mixed precision
-  bf16: false                     # turn off to avoid dual precision modes
+  fp16: false                     # disabled in favour of bf16
+  bf16: true                      # bf16 mixed precision (keep only one mode enabled)
   # torch_dtype: "auto"           # (optional) lets HF pick fastest dtype
 
   # ── optimiser & scheduler ─────────────────────────────────────────────────
-  learning_rate: 0.00005          # good starting LR for GPT-2 on small corpora
+  learning_rate: 0.00003          # good starting LR for GPT-2 on small corpora
   warmup_ratio: 0.1
 
   # ── misc performance knobs ────────────────────────────────────────────────
-  dataloader_num_workers: 4
+  dataloader_num_workers: 8
   gradient_checkpointing: true    # big memory win on GPT-style decoders
   max_grad_norm: 1.0
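
The fp16 → bf16 flip assumes the training GPU has native bfloat16 support (Ampere or newer). A minimal pre-flight check, as a sketch:

import torch

assert torch.cuda.is_available() and torch.cuda.is_bf16_supported(), \
    "bf16 requested but unsupported on this GPU; fall back to fp16"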

configs/synthetic/conf.yaml

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
model:
  # Generator checkpoint (fine-tuned GPT-2, etc.)
  ckpt_dir: "runs/generator/gpt2_sst2/checkpoint-3500"

teacher:
  ckpt_dir: "runs/teacher/deberta_v3_base/"
  min_confidence: 0.85       # accept a sample only if the teacher's score ≥ this value

data:
  output_dir: "data/synthetic_sst2"

  # Target number of *accepted* samples (total across classes)
  n_samples_total: 20000

  # Fractions must sum to 1.0
  split_ratio:
    train: 0.9
    val: 0.05
    test: 0.05

generation:
  batch_size: 128            # try 128 if max_new_tokens stays at 64
  temperature: 0.8           # a tad warmer improves diversity and cuts the rejection rate
  num_return_sequences: 2    # doubles raw throughput with the same kernels
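
With the defaults above, the split works out to 20000 × 0.9 = 18000 train, 20000 × 0.05 = 1000 val and 20000 × 0.05 = 1000 test samples, all drawn from the pool of teacher-accepted generations.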

data/models/.gitkeep

Whitespace-only changes.

data/processed/.gitkeep

Whitespace-only changes.

data/raw/.gitkeep

Whitespace-only changes.

data/results/.gitkeep

Whitespace-only changes.

src/cli/02_fine_tune_generator.py

Lines changed: 15 additions & 0 deletions
@@ -40,6 +40,21 @@ def main(
     # ------------------------------------------------------- MODEL & TOKENISER
     model, tokenizer = build_generator(cfg["model"])
 
+    # ➡️ 1. Add sentiment-prefix tokens so the model always sees them first
+    special_tokens = {"additional_special_tokens": ["<POS>", "<NEG>"]}
+    tokenizer.add_special_tokens(special_tokens)
+    model.resize_token_embeddings(len(tokenizer))
+
+    def freeze_lower_layers(m, keep_last=4):  # ➡️ 2. freeze all blocks below the top `keep_last`
+        total = len(m.transformer.h)
+        cutoff = total - keep_last
+        for i, block in enumerate(m.transformer.h):
+            if i < cutoff:
+                for p in block.parameters():
+                    p.requires_grad = False
+
+    freeze_lower_layers(model, keep_last=cfg["model"].get("unfrozen_layers", 4))
+
     # ----------------------------------------------------------------- DATA
     dm = GeneratorDataModule(cfg["data"], tokenizer)
     dm.setup()
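
A quick way to confirm the two additions behave as intended is to count trainable parameters after freezing. The snippet below is a self-contained sketch: it uses the stock "gpt2" checkpoint and the default keep_last=4 in place of the values loaded from the config.

from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
mdl = AutoModelForCausalLM.from_pretrained("gpt2")

# mirror the diff: register the prefix tokens and grow the embedding matrix
tok.add_special_tokens({"additional_special_tokens": ["<POS>", "<NEG>"]})
mdl.resize_token_embeddings(len(tok))

# mirror freeze_lower_layers(model, keep_last=4)
cutoff = len(mdl.transformer.h) - 4
for i, block in enumerate(mdl.transformer.h):
    if i < cutoff:
        for p in block.parameters():
            p.requires_grad = False

trainable = sum(p.numel() for p in mdl.parameters() if p.requires_grad)
total = sum(p.numel() for p in mdl.parameters())
print(f"trainable: {trainable:,} / {total:,} parameters")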

src/cli/03_create_synthetic_dataset.py

Lines changed: 261 additions & 0 deletions
@@ -0,0 +1,261 @@
#!/usr/bin/env python
"""
03_create_synthetic_dataset.py
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Generate a **quality-controlled** synthetic SST-2-style dataset using a
fine-tuned causal LM (e.g. GPT-2) **and** a teacher classifier. A sample is
accepted only when the teacher's confidence exceeds a configurable
threshold. The teacher's prediction provides the final label, so we are
robust to generator drift (e.g., a <POS> prompt that meanders into a
negative review).

Key update (2025-05-13)
-----------------------
* **FIXED**: Hugging Face `pipeline` returns a *list of lists* when you pass
  a batch of prompts. The script now flattens the nested structure, so we
  no longer hit `TypeError: list indices must be integers or slices, not str`.

Configuration (YAML)
--------------------
model:
  ckpt_dir: "runs/generator/gpt2_sst2/checkpoint-best"

teacher:
  ckpt_dir: "runs/teacher/deberta_v3_base/checkpoint-2000"
  min_confidence: 0.8

data:
  output_dir: "data/synthetic_sst2"
  n_samples_total: 20000
  split_ratio: {train: 0.9, val: 0.05, test: 0.05}

generation:
  max_new_tokens: 64
  temperature: 0.7
  top_k: 50
  top_p: 0.95
  repetition_penalty: 1.2
  seed: 42
  batch_size: 8
"""

from __future__ import annotations

import json
import logging
import math
import queue
import random
import threading
from pathlib import Path
from typing import Dict, List

import torch
import typer
import yaml
from tqdm.auto import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    pipeline,
)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s — %(levelname)s — %(message)s",
)
logger = logging.getLogger(__name__)

app = typer.Typer()

PROMPTS = ["<POS> Review:", "<NEG> Review:"]


def _set_seed(seed: int | None):
    if seed is None:
        return
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def _clean_output(generated: str) -> str:
    """Remove the sentiment prefix and tidy whitespace."""
    try:
        cleaned = generated.split(":", 1)[1].strip()
    except IndexError:
        cleaned = generated.strip()
    return " ".join(cleaned.split())


def _write_jsonl(items: List[Dict[str, str]], path: Path):
    with path.open("w", encoding="utf-8") as f:
        for obj in items:
            json.dump(obj, f, ensure_ascii=False)
            f.write("\n")
    logger.info("Wrote %d samples → %s", len(items), path)


def _label_from_teacher(res: Dict[str, str | float]) -> int | None:
    label = res["label"].lower()
    if label in {"positive", "pos", "label_1", "1"}:
        return 1
    if label in {"negative", "neg", "label_0", "0"}:
        return 0
    return None

@app.command()
def main(cfg_path: Path = typer.Argument(..., help="Path to YAML config")):
    # ------------------------------- Load config
    cfg = yaml.safe_load(cfg_path.read_text())

    gen_ckpt = Path(cfg["model"]["ckpt_dir"]).expanduser()
    teacher_ckpt = Path(cfg["teacher"]["ckpt_dir"]).expanduser()
    min_conf = float(cfg["teacher"].get("min_confidence", 0.8))

    n_target = int(cfg["data"].get("n_samples_total", 20000))
    split_ratio = cfg["data"].get(
        "split_ratio", {"train": 0.9, "val": 0.05, "test": 0.05}
    )
    if not math.isclose(sum(split_ratio.values()), 1.0, abs_tol=1e-6):
        raise ValueError("Split ratios must sum to 1.0")

    gcfg = cfg.get("generation", {})
    _set_seed(gcfg.get("seed", 42))

    # --------------------------------------- Generator pipeline (GPU if available)
    use_cuda = torch.cuda.is_available()
    device = "cuda:0" if use_cuda else "cpu"
    tokenizer_gen = AutoTokenizer.from_pretrained(gen_ckpt)
    model_gen = (
        AutoModelForCausalLM.from_pretrained(
            gen_ckpt,
            torch_dtype=torch.float16 if use_cuda else torch.float32,
        )
        .to(device)
        .eval()
    )
    model_gen = torch.compile(model_gen, mode="reduce-overhead", fullgraph=False)

    # Pre-encode the <POS>/<NEG> prompts once
    prompt_ids = tokenizer_gen(
        PROMPTS, add_special_tokens=False, return_tensors="pt"
    ).input_ids.to(device)

    # --------------------------------------- Teacher pipeline (CPU)
    tok = AutoTokenizer.from_pretrained(teacher_ckpt)
    torch.set_num_threads(30)  # tune to the number of available CPU cores
    model = AutoModelForSequenceClassification.from_pretrained(
        teacher_ckpt
    ).eval()
    teacher_pipe = pipeline(
        "text-classification",
        model=model,
        tokenizer=tok,
        device=-1,  # run the teacher on CPU
        batch_size=1024,
        truncation=True,
    )

    # Generation params
    gen_params = {
        "max_new_tokens": gcfg.get("max_new_tokens", 64),
        "temperature": gcfg.get("temperature", 0.7),
        "top_k": gcfg.get("top_k", 50),
        "top_p": gcfg.get("top_p", 0.95),
        "repetition_penalty": gcfg.get("repetition_penalty", 1.2),
        "eos_token_id": tokenizer_gen.eos_token_id,
        "do_sample": True,
        "num_return_sequences": 1,
    }
    batch_size = int(gcfg.get("batch_size", 128))

    # --------------------------------------- Main generation loop
    dataset: List[Dict[str, str]] = []
    rejects = 0
    pbar = tqdm(total=n_target, desc="Accepted samples")

    # ── Async teacher consumer on CPU ───────────────────────────────
    # Each result carries its own texts so labels stay paired with the
    # batch they were computed for, even when the GPU runs ahead.
    work_q: queue.Queue[list[str] | None] = queue.Queue(maxsize=4)
    res_q: queue.Queue[tuple[list[str], list[Dict]]] = queue.Queue(maxsize=4)

    def _cpu_teacher():
        while True:
            batch_txt = work_q.get()
            if batch_txt is None:
                break
            res_q.put((batch_txt, teacher_pipe(batch_txt)))
            work_q.task_done()

    t = threading.Thread(target=_cpu_teacher, daemon=True)
    t.start()

    while len(dataset) < n_target:
        # 1️⃣ build input tensor (no re-tokenisation)
        rep = (batch_size + len(PROMPTS) - 1) // len(PROMPTS)
        input_ids = prompt_ids.repeat_interleave(rep, 0)[:batch_size]
        attention_mask = torch.ones_like(input_ids, dtype=torch.long)
        with torch.inference_mode():
            gen_ids = model_gen.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                pad_token_id=tokenizer_gen.pad_token_id or tokenizer_gen.eos_token_id,  # GPT-2 has no pad token by default
                **gen_params,
            )
        outputs = tokenizer_gen.batch_decode(gen_ids, skip_special_tokens=True)

        texts = [_clean_output(t) for t in outputs]
        texts = [t for t in texts if t]
        # 2️⃣ hand off to the CPU thread and immediately start the next generation
        work_q.put(texts)

        # 3️⃣ drain any finished label batches
        while not res_q.empty() and len(dataset) < n_target:
            batch_txt, teacher_out = res_q.get()
            for res, text in zip(teacher_out, batch_txt):
                conf = res["score"]
                label = _label_from_teacher(res)
                if label is not None and conf >= min_conf:
                    dataset.append({"label": label, "text": text})
                    pbar.update(1)
                    if len(dataset) >= n_target:
                        break
                else:
                    rejects += 1
                    logger.debug(
                        "Rejected (conf=%.3f, label=%s): %.60s",
                        conf,
                        label,
                        text,
                    )

    work_q.put(None)  # stop the CPU thread
    t.join()

    logger.info(
        "Finished generation: %d accepted, %d rejected (%.2f%% rejection)",
        len(dataset),
        rejects,
        100 * rejects / (len(dataset) + rejects),
    )

    # --------------------------------------- Shuffle & split
    random.shuffle(dataset)
    n_train = int(n_target * split_ratio["train"])
    n_val = int(n_target * split_ratio["val"])
    splits = {
        "train": dataset[:n_train],
        "val": dataset[n_train : n_train + n_val],
        "test": dataset[n_train + n_val :],
    }

    # --------------------------------------- Write files
    out_dir = Path(cfg["data"]["output_dir"]).expanduser()
    out_dir.mkdir(parents=True, exist_ok=True)
    for split_name, items in splits.items():
        _write_jsonl(items, out_dir / f"{split_name}.jsonl")

    logger.info("\u2705 Synthetic dataset ready → %s", out_dir)


if __name__ == "__main__":
    app()
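
Each split is written as JSON Lines, one {"label": 0|1, "text": "..."} object per row. A downstream student-training step could load the files with the datasets library, for example (a sketch; the paths assume the default output_dir from the config above):

from datasets import load_dataset

ds = load_dataset(
    "json",
    data_files={
        "train": "data/synthetic_sst2/train.jsonl",
        "validation": "data/synthetic_sst2/val.jsonl",
        "test": "data/synthetic_sst2/test.jsonl",
    },
)
print(ds["train"][0])  # e.g. {"label": 1, "text": "..."}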
