feat: add Connect Four LLM training pipeline (SFT + GRPO) #2
Conversation
Clean slate for a new solver implementation. https://claude.ai/code/session_01Y8peoctJpu7uW3znbX7WMM
Implements a high-performance negamax solver with alpha-beta pruning and transposition table for exact game-theoretic scoring of Connect4 positions. Includes systematic and random position generators that output normalized CSV datasets for LLM fine-tuning. https://claude.ai/code/session_01Y8peoctJpu7uW3znbX7WMM
The previous direction_wins helper was finding existing 4-in-a-rows instead of the empty gap cells where a piece would complete one. The correct approach enumerates all 4 gap patterns (XXX_, _XXX, XX_X, X_XX) per direction. Vertical only needs one pattern due to gravity. Also consolidates winning position logic into Board, removing the duplicate implementation from Solver. ~30x speedup on typical positions. https://claude.ai/code/session_01Y8peoctJpu7uW3znbX7WMM
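The gap-pattern enumeration described above can be sketched with plain integers standing in for the bitboards. This is an illustrative Python model, not the PR's Rust code; the Pons-style layout (bit index = col*7 + row, with an always-empty sentinel row per column) and all names here are assumptions:

```python
# Illustrative bitboard model (assumed layout: bit = col*7 + row, row 6 of each
# column is an always-empty sentinel). All names are hypothetical.
WIDTH, HEIGHT = 7, 6
BOTTOM_MASK = sum(1 << (c * (HEIGHT + 1)) for c in range(WIDTH))
BOARD_MASK = BOTTOM_MASK * ((1 << HEIGHT) - 1)

def winning_cells(pos, occupied):
    """Empty cells where dropping a piece completes a 4-in-a-row for `pos`."""
    # Vertical: gravity means only the XXX_ pattern (three pieces below) matters.
    r = (pos << 1) & (pos << 2) & (pos << 3)
    # Horizontal (shift 7) and both diagonals (shifts 6 and 8): all 4 gap patterns.
    for s in (HEIGHT + 1, HEIGHT, HEIGHT + 2):
        r |= (pos << s) & (pos << 2 * s) & (pos << 3 * s)   # XXX_
        r |= (pos >> s) & (pos >> 2 * s) & (pos >> 3 * s)   # _XXX
        r |= (pos << s) & (pos << 2 * s) & (pos >> s)       # XX_X
        r |= (pos << s) & (pos >> s) & (pos >> 2 * s)       # X_XX
    # Keep only empty on-board cells (drops phantom bits from the sentinel row).
    return r & BOARD_MASK & ~occupied
```

With pieces at columns 1-3 of the bottom row, the two completing cells at columns 0 and 4 come back as set bits.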
Four bugs fixed:
1. Transposition table now stores bound type (Exact/Lower/Upper) instead of treating all values the same. The old code could use an upper bound for a beta cutoff or raise alpha with a lower bound incorrectly, producing wrong scores.
2. Node count now accumulates across all column solves in rank_moves instead of being reset by each solve() call (was showing "1").
3. Terminal loss score formula fixed: -(43-(moves-1))/2 instead of -(43-moves)/2 to match the win detection formula. This off-by-one caused self-consistency failures on terminal positions.
4. CLI now detects and correctly reports game-over positions instead of silently showing "DRAW" for won games.
https://claude.ai/code/session_01Y8peoctJpu7uW3znbX7WMM
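Fix (1) can be illustrated with a minimal bound-aware probe routine. This is a generic alpha-beta transposition-table sketch, not the PR's Rust implementation; names and the tuple layout are assumptions:

```python
# Bound types for a fail-soft alpha-beta transposition table (illustrative).
EXACT, LOWER, UPPER = 0, 1, 2

def probe(entry, alpha, beta):
    """Use a TT entry correctly per its bound type.

    Returns (score, alpha, beta); score is not None when the entry
    yields an immediate cutoff.
    """
    value, bound = entry
    if bound == EXACT:
        return value, alpha, beta
    if bound == LOWER:
        alpha = max(alpha, value)   # a lower bound may only raise alpha
    elif bound == UPPER:
        beta = min(beta, value)     # an upper bound may only lower beta
    if alpha >= beta:
        return value, alpha, beta   # window closed: cutoff
    return None, alpha, beta        # keep searching with the tightened window
```

The buggy behavior described above corresponds to ignoring `bound` and treating every stored value as if it were exact.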
…ering Major solver optimizations:
- Iterative deepening with null-window search: binary-searches for the exact score using O(log(score_range)) null-window calls instead of one wide window. Each call maximizes alpha-beta cutoffs.
- Non-losing moves filter: bitwise exclusion of moves that play directly below an opponent winning cell, pruning 30-50% of the tree.
- Dynamic move ordering: scores each candidate move by the number of winning positions it creates, explores best moves first.
- TT increased to 24M entries for better cache behavior.
- Lower/upper bound initialization from TT in solve() for a tighter initial window.
Performance: 15-move positions solve in ~250ms (was 290ms). Positions with 8+ moves are fast enough for training data generation. Early game (<6 moves) still needs more advanced techniques (anticipation).
Also fixes: bench command uses ROWS constant instead of hardcoded 6.
https://claude.ai/code/session_01Y8peoctJpu7uW3znbX7WMM
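The null-window binary search in the first bullet can be sketched as follows. This is a generic sketch, not the PR's code; it assumes a fail-soft `probe(mid)` that searches the null window (mid, mid+1) and returns a value r with true score <= r on fail-low and true score >= r on fail-high:

```python
def solve_exact(probe, lo, hi):
    """Home in on the exact game score with O(log(hi - lo)) null-window calls.

    `probe(mid)` is assumed to be a fail-soft null-window search of the window
    (mid, mid + 1): it returns r <= mid when the true score is <= r, and
    r > mid when the true score is >= r.
    """
    while lo < hi:
        mid = lo + (hi - lo) // 2
        r = probe(mid)
        if r <= mid:
            hi = r   # fail-low: true score is at most r
        else:
            lo = r   # fail-high: true score is at least r
    return lo
```

Each probe maximizes cutoffs because with a zero-width window every interior node fails high or low immediately.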
Generator improvements:
- Rayon parallelism: positions are generated sequentially (cheap) then solved in parallel using thread-local Solvers. Each rayon thread gets its own TT for maximum cache efficiency.
- Cross-batch dedup: shared HashSet persists across batch boundaries, preventing duplicate positions between CSV files.
- Cross-phase dedup: generate_systematic returns its seen set, which is passed to the random phase in generate_full.
- Min solve depth: systematic generator skips solving positions with fewer than 8 moves (too expensive), while still enumerating them for DFS traversal to reach deeper positions.
- generate_systematic/generate_random no longer take &mut Solver (each thread creates its own via thread_local).
https://claude.ai/code/session_01Y8peoctJpu7uW3znbX7WMM
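The thread-local-solver pattern can be mimicked in Python with `threading.local`. This is only an analogy to the rayon design described above; the real generator is Rust, and every name here is a placeholder:

```python
import threading
from concurrent.futures import ThreadPoolExecutor

# Analogue of the rayon pattern above: each worker thread lazily builds its own
# "solver" (here a dict standing in for a transposition table), so worker
# caches are never shared or contended.
_tls = threading.local()

def _solver():
    if not hasattr(_tls, "solver"):
        _tls.solver = {}   # per-thread TT stand-in
    return _tls.solver

def solve_position(pos):
    tt = _solver()
    if pos not in tt:
        tt[pos] = pos * 2  # pretend "solve"; the real code would run negamax here
    return tt[pos]

def solve_all(positions, workers=4):
    # Positions are produced sequentially (cheap) and solved in parallel.
    with ThreadPoolExecutor(max_workers=workers) as ex:
        return list(ex.map(solve_position, positions))
```

The design choice mirrors the commit: per-worker state avoids lock contention on a shared table at the cost of some duplicated work between workers.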
Four key optimizations:
1. Anticipation: after computing non-losing moves, check if each candidate allows the opponent to create a double threat (2+ winning positions on playable cells with a single move). Prune such moves before searching.
2. TT best-move ordering: store the best move (column) in TT entries and use it to explore the most promising move first. The TT move from previous null-window iterations is almost always the best move, giving much earlier alpha-beta cutoffs.
3. Packed TT entries: value (6 bits) + bound (2 bits) + best_move (3 bits) packed into u16, reducing per-entry memory from 18 bytes to 10 bytes.
4. O(1) legal_moves_mask via bottom-mask trick: (mask + BOTTOM_MASK) & BOARD_MASK replaces the 7-iteration column loop.
Combined effect: ~21% node reduction on depth-6 positions (44M → 35M). Early game (<4 moves) remains slow — would need deeper anticipation or opening book to match Pons' reference solver.
https://claude.ai/code/session_01Y8peoctJpu7uW3znbX7WMM
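Items 3 and 4 are easy to sketch in Python with the bit widths and masks described above (the layout and names are illustrative, not the PR's actual identifiers):

```python
# Assumed Pons-style layout: bit = col*7 + row, sentinel row 6 per column.
WIDTH, HEIGHT = 7, 6
BOTTOM_MASK = sum(1 << (c * (HEIGHT + 1)) for c in range(WIDTH))  # lowest cell per column
BOARD_MASK = BOTTOM_MASK * ((1 << HEIGHT) - 1)                    # all playable cells

def legal_moves_mask(occupied):
    # Adding BOTTOM_MASK carries each column's filled run up one cell, leaving
    # exactly one bit at the lowest empty cell of every column; a full column's
    # carry lands in the sentinel row and BOARD_MASK strips it. No loop needed.
    return (occupied + BOTTOM_MASK) & BOARD_MASK

def pack_entry(value, bound, best_move):
    # value (offset to 6 bits), bound (2 bits), best_move (3 bits) -> one u16.
    # The +21 offset is an assumption covering the Connect 4 score range.
    return ((value + 21) & 0x3F) | (bound << 6) | (best_move << 8)

def unpack_entry(e):
    return (e & 0x3F) - 21, (e >> 6) & 0x3, (e >> 8) & 0x7
```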
Critical fix + major performance improvement based on Pons' reference:
1. Fix: non_losing_moves() now masks compute_winning_positions output
with (BOARD_MASK & !mask) to restrict to empty board cells. Previously
phantom bits in sentinel rows leaked through the >> 1 shift, causing
valid moves to be incorrectly excluded. Also moved double-threat
detection (forced & (forced-1)) into non_losing_moves itself.
2. Compact TT matching Pons' design:
- u32 partial keys + u8 values = 5 bytes/entry (was 10 bytes)
- Pons encoding: upper bounds as alpha-MIN+1 (1..37), lower bounds
as score+MAX-2*MIN+2 (38..74), 0 = empty. No separate bound flag.
- 8.4M entries = 42MB (fits near 33MB L3 cache)
- Result: 56% faster node throughput (6.7M vs 4.3M nodes/sec)
3. Removed can_create_double_threat anticipation from negamax (Pons
doesn't use it — overhead exceeded benefit).
4. Matched Pons' column order: {3,4,2,5,1,6,0}.
Net: depth-6 position 9.9s → 5.1s (48% faster).
https://claude.ai/code/session_01Y8peoctJpu7uW3znbX7WMM
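The Pons value encoding in (2) folds the bound type into the stored byte's numeric range instead of keeping a separate flag. A hedged Python sketch, assuming the standard 7×6 score bounds of ±18 (names are illustrative):

```python
MIN_SCORE, MAX_SCORE = -18, 18   # assumed score range for the 7x6 board

def encode_upper(v):
    return v - MIN_SCORE + 1                   # maps MIN..MAX onto 1..37

def encode_lower(v):
    return v + MAX_SCORE - 2 * MIN_SCORE + 2   # maps MIN..MAX onto 38..74

def decode(stored):
    if stored == 0:                            # 0 = empty slot
        return None
    if stored > MAX_SCORE - MIN_SCORE + 1:     # values above 37 are lower bounds
        return ("lower", stored - (MAX_SCORE - 2 * MIN_SCORE + 2))
    return ("upper", stored + MIN_SCORE - 1)
```

Everything fits in one u8 next to a u32 partial key, which is where the 5-byte entries above come from.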
Exposes min_solve_depth as a CLI parameter so users can control which depths get solved. Default is 0 (solve all positions, including early game). On slower machines, use --min-depth 8 to skip expensive early positions.
Usage:
  generate systematic --max-depth 10 --min-depth 0   # solve everything
  generate full --min-depth 8                        # skip depths 0-7
  generate full                                      # default: solve all
https://claude.ai/code/session_01Y8peoctJpu7uW3znbX7WMM
1. Cargo.toml: use caret requirements ("1" instead of "1.3"/"1.10")
for csv and rayon dependencies.
2. board.rs: undo() no longer takes a col parameter — uses
move_history.pop() to determine which column, preventing
mismatched undo calls.
3. board.rs: test_is_full now tests both empty and full boards.
4. board.rs: test_non_losing_moves now verifies specific forced
blocking behavior (col 2 must be included when P2 threatens
horizontal completion there).
5. main.rs: terminal position message clarified — "current player
LOST" instead of ambiguous "player who just moved WON" + negative
score.
https://claude.ai/code/session_01Y8peoctJpu7uW3znbX7WMM
Rewrite solver in Rust with bitboard engine and training data generator
Code Review
This pull request introduces a training pipeline for Connect Four LLM agents using SFT and GRPO via the Unsloth library, including a comprehensive RunPod setup guide and evaluation scripts. The review feedback identifies critical improvements regarding data leakage and reproducibility in the dataset splitting logic, as well as potential runtime errors due to division by zero in the evaluation statistics. Additionally, recommendations were made to refactor global state into a more encapsulated class structure and to clarify command-line examples in the documentation for better usability.
```python
if not checkpoint_dir: print("ERROR: No model found"); return
model, tokenizer = FastLanguageModel.from_pretrained(model_name=checkpoint_dir, max_seq_length=config["max_seq_length"], load_in_4bit=config["load_in_4bit"])
FastLanguageModel.for_inference(model)
raw_data = load_csv_data(config["csv_path"]); random.seed(42); random.shuffle(raw_data); eval_data = raw_data[-10_000:]
```
There are a couple of issues with how the training and evaluation datasets are split, which can lead to data leakage and lack of reproducibility:
- Data Leakage: The training data is taken from the start of the shuffled list (`data[:max_rows]`) in `prepare_sft_dataset` and `prepare_grpo_dataset`, while the evaluation data is taken from the end (`raw_data[-10_000:]`). If `sft_max_rows` or `grpo_max_rows` is large enough to overlap with the last 10,000 samples, the model will be evaluated on data it has already seen during training.
- Reproducibility: `random.shuffle(data)` is called in `prepare_sft_dataset` and `prepare_grpo_dataset` without setting a random seed. This means the training set will be different on each run, making experiments difficult to reproduce. While `run_eval` sets a seed, the training data preparation does not.
A better approach is to perform a single, deterministic train/test split at the beginning of the script or in `main`.
Here's a suggested refactoring:

```python
# In main()
raw_data = load_csv_data(config["csv_path"])
random.seed(42)
random.shuffle(raw_data)
eval_data = raw_data[-10_000:]
train_data = raw_data[:-10_000]
# Then pass `train_data` to run_sft/run_grpo and `eval_data` to run_eval
# and remove data loading and splitting from within those functions.
# Also remove `random.shuffle` from the `prepare_*_dataset` functions.
```

```python
print(f"  Valid: {valid}/{total} ({100*valid/total:.1f}%)")
print(f"  Exact match: {exact}/{valid} ({100*exact/valid:.1f}%)")
print(f"  Top-2 match: {top2}/{valid} ({100*top2/valid:.1f}%)")
print(f"  Mean oracle score: {score_sum/valid:+.4f}")
for phase, s in phase_stats.items():
    if s["total"] > 0: print(f"  {phase}: {100*s['correct']/s['total']:.1f}% exact, {s['score_sum']/s['total']:+.3f} avg (n={s['total']})")
results = {"model": config["model_name"], "model_size": config["model_size"], "exact_match_pct": round(100*exact/valid, 2), "top2_match_pct": round(100*top2/valid, 2), "mean_oracle_score": round(score_sum/valid, 4)}
with open(f"eval_results_{config['model_size']}.json", "w") as f: json.dump(results, f, indent=2)
```
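The summary printout above divides by `total` and `valid` unguarded, which is the division-by-zero risk the review summary calls out. A hedged sketch of a guarded version (the function name and message wording are invented for illustration):

```python
def format_stats(valid, total, exact, top2, score_sum):
    """Build the eval summary lines without risking ZeroDivisionError."""
    if total == 0:
        return ["  No samples evaluated."]
    lines = [f"  Valid: {valid}/{total} ({100*valid/total:.1f}%)"]
    if valid == 0:
        # No denominator for the accuracy stats, so stop here.
        lines.append("  No valid completions; skipping accuracy stats.")
        return lines
    lines.append(f"  Exact match: {exact}/{valid} ({100*exact/valid:.1f}%)")
    lines.append(f"  Top-2 match: {top2}/{valid} ({100*top2/valid:.1f}%)")
    lines.append(f"  Mean oracle score: {score_sum/valid:+.4f}")
    return lines
```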
```
python connect4_train.py --model {4b,8b,14b} --stage export --csv connect4_data.csv
```
The placeholder {4b,8b,14b} in the command could be slightly confusing for users. It's clearer to provide a concrete example to show how to replace the placeholder.
Suggested change:

```diff
-python connect4_train.py --model {4b,8b,14b} --stage export --csv connect4_data.csv
+python connect4_train.py --model 8b --stage export --csv connect4_data.csv
```
```python
formatted.append({"prompt": tokenizer.apply_chat_template(conv, tokenize=False, add_generation_prompt=True), "move_sequence": entry["move_sequence"]})
return Dataset.from_list(formatted)
```

```python
SCORE_LOOKUP = {}
```
Using a global variable SCORE_LOOKUP that is modified in run_grpo and read by the reward functions makes the code harder to reason about and test. This creates a hidden dependency between these functions.
A better approach would be to encapsulate this state. For example, you could use a class to hold the lookup table and define the reward functions as its methods. This would make the data flow explicit.
Example refactoring:
```python
class RewardCalculator:
    def __init__(self, data):
        self.score_lookup = build_lookup_table(data)

    def reward_move_quality(self, completions, move_sequence, **kwargs):
        # ... use self.score_lookup
        ...

# In run_grpo:
raw_data = load_csv_data(config["csv_path"])
reward_calculator = RewardCalculator(raw_data)
trainer = GRPOTrainer(
    ...
    reward_funcs=[
        reward_calculator.reward_format,
        reward_calculator.reward_move_quality,
        reward_calculator.reward_is_best_move
    ],
    ...
)
```

- New --stage push uploads merged 16-bit + GGUF to HF Hub
- --hf-repo flag to specify target repo (creates -GGUF sibling repo)
- Updated RunPod guide with Step 6 and huggingface-cli login
https://claude.ai/code/session_015dRfJUv1tpJrJ4AVgqLaoL
…wards
- Single deterministic train/eval split in main() prevents data leakage
- Remove random.shuffle from prepare_*_dataset (split is done upfront)
- Add ZeroDivisionError guards for total==0 and valid==0 in eval
- Replace global SCORE_LOOKUP with RewardCalculator class
- Use concrete example in RUNPOD_SETUP.md export command
https://claude.ai/code/session_015dRfJUv1tpJrJ4AVgqLaoL
/gemini review
Code Review
This pull request introduces a Connect Four LLM training pipeline utilizing SFT and GRPO with the Unsloth library, complemented by RunPod setup documentation and environment configuration. The reviewer suggested several improvements, including exposing more training parameters as command-line arguments, adhering to PEP 8 standards for imports and assignments, removing unused imports, and avoiding the use of global variables to improve code maintainability.
- Lower temperature (try 0.5 instead of 0.7) — the model may be generating too randomly.
- Increase `grpo_grad_accum` to 8 or 16 for more stable gradients.
- Try switching to `loss_type="dr_grpo"` in the GRPOConfig (edit the script).
```python
python connect4_train.py --model 8b --stage push --hf-repo yourname/connect4-agent-8b
"""

import argparse, csv, json, os, re, random
from collections import defaultdict
```
```python
parser.add_argument("--hf-repo", default=None, help="HuggingFace repo id for push (e.g. yourname/connect4-agent-8b)")
parser.add_argument("--no-wandb", action="store_true")
args = parser.parse_args()
if args.no_wandb: global WANDB_AVAILABLE; WANDB_AVAILABLE = False
```
Modifying a global variable (WANDB_AVAILABLE) directly within a function is generally discouraged as it can lead to unexpected side effects and make code harder to reason about. Consider passing WANDB_AVAILABLE as an argument to main or encapsulating this logic within a class.
Suggested change:

```diff
-if args.no_wandb: global WANDB_AVAILABLE; WANDB_AVAILABLE = False
+if args.no_wandb:
+    global WANDB_AVAILABLE
+    WANDB_AVAILABLE = False
```
```python
parser.add_argument("--no-wandb", action="store_true")
args = parser.parse_args()
if args.no_wandb: global WANDB_AVAILABLE; WANDB_AVAILABLE = False
config = get_config(args.model); config["csv_path"] = args.csv; config["hf_repo"] = args.hf_repo
```
Multiple assignments on a single line can sometimes reduce readability. For better clarity and adherence to PEP 8, consider splitting this into separate lines.
Suggested change:

```diff
-config = get_config(args.model); config["csv_path"] = args.csv; config["hf_repo"] = args.hf_repo
+config = get_config(args.model)
+config["csv_path"] = args.csv
+config["hf_repo"] = args.hf_repo
```
- PEP 8: one import per line, remove unused defaultdict import
- Move global WANDB_AVAILABLE mutation to separate lines
- Split chained assignments in main() for readability
- Add --loss-type CLI flag for GRPO (e.g. --loss-type dr_grpo)
- Update RUNPOD_SETUP.md troubleshooting to reference CLI flag
https://claude.ai/code/session_015dRfJUv1tpJrJ4AVgqLaoL
…inking
- Reconstruct board from move_sequence (mirrors game.py exactly)
- Use production SYSTEM_PROMPT and USER_TEMPLATE from ai.py
- Drop SFT stage — Qwen3 already knows JSON and reasons natively
- GRPO-only with extended thinking: max_completion_length=3072
- JSON-aware output parsing (handles <think> blocks)
- Tiered format reward: +0.5 JSON, +0.25 regex, -1.0 invalid
- Eval now reports JSON format % alongside accuracy metrics
- Reduced num_generations (4/3/3) to fit longer completions in VRAM
https://claude.ai/code/session_015dRfJUv1tpJrJ4AVgqLaoL
/Gemini review
Code Review
This pull request introduces a comprehensive training and evaluation pipeline for Connect Four agents using GRPO and Unsloth, accompanied by a detailed RunPod setup guide and updated ignore rules. The implementation covers board reconstruction, reward calculation, and model export functionality. The review feedback provides several actionable improvements to enhance the robustness and maintainability of the code, including better CSV file handling, ensuring deterministic evaluation through greedy decoding, preventing redundant repository naming suffixes during HuggingFace uploads, and refactoring global state management to avoid side effects.
```python
def load_csv_data(csv_path):
    data = []
    with open(csv_path, "r") as f:
```
For improved robustness and to prevent platform-specific issues, it's best practice to specify encoding="utf-8" and newline="" when opening CSV files with Python's csv module. This ensures consistent behavior across different operating systems and correctly handles newlines within quoted fields.
Suggested change:

```diff
-with open(csv_path, "r") as f:
+with open(csv_path, "r", encoding="utf-8", newline="") as f:
```
```python
).to(model.device)

with torch.no_grad():
    outputs = model.generate(input_ids=inputs, max_new_tokens=3072, temperature=0.01, do_sample=True)
```
The evaluation process is currently stochastic due to do_sample=True. For reproducible and deterministic evaluation results, it's recommended to use greedy decoding. This can be achieved by setting do_sample=False. The temperature parameter can then be removed as it has no effect with greedy decoding.
Suggested change:

```diff
-outputs = model.generate(input_ids=inputs, max_new_tokens=3072, temperature=0.01, do_sample=True)
+outputs = model.generate(input_ids=inputs, max_new_tokens=3072, do_sample=False)
```
```python
    print("ERROR: huggingface_hub not installed. Run: pip install huggingface_hub")
    return
from huggingface_hub import HfApi, create_repo
hf_repo = config.get("hf_repo")
```
The push_to_hub function constructs the GGUF repository name by appending -GGUF to the provided --hf-repo argument. If a user mistakenly provides a repo name that already ends with -GGUF, it will result in a doubly-appended name (e.g., my-repo-GGUF-GGUF). It would be more robust to add a check for this.
You could add a check after getting hf_repo from the config:
```python
hf_repo = config.get("hf_repo")
if hf_repo and hf_repo.upper().endswith("-GGUF"):
    print("WARNING: --hf-repo should be the base repo name. Stripping '-GGUF' suffix.")
    hf_repo = hf_repo[:-5]
```

```python
if args.no_wandb:
    global WANDB_AVAILABLE
    WANDB_AVAILABLE = False
```
Modifying a global variable (WANDB_AVAILABLE) within the main function can make the code harder to understand and maintain. A cleaner approach is to determine the desired wandb state within main and pass it as an argument to the functions that need it. This avoids side effects and makes data flow explicit.
Consider this alternative approach:

1. In `main()`, remove the global modification and instead create a local variable and pass it via the `config` dictionary:

```python
# In main()
use_wandb = WANDB_AVAILABLE and not args.no_wandb
config["use_wandb"] = use_wandb
```

2. Update `run_grpo` to use this config value:

```python
# In run_grpo()
use_wandb = config.get("use_wandb", WANDB_AVAILABLE)  # Fallback to global default
if use_wandb:
    wandb.init(...)
grpo_kwargs = dict(
    ...
    report_to="wandb" if use_wandb else "none",
    ...
)
```

This makes the dependency on wandb explicit.
- Add encoding="utf-8" and newline="" to CSV open for cross-platform safety
- Use do_sample=False in eval for deterministic, reproducible results
- Guard against --hf-repo ending in -GGUF (strip suffix with warning)
- Pass wandb state via config["use_wandb"] instead of mutating global
https://claude.ai/code/session_015dRfJUv1tpJrJ4AVgqLaoL
https://claude.ai/code/session_015dRfJUv1tpJrJ4AVgqLaoL