feat: add Connect Four LLM training pipeline (SFT + GRPO)#2

Open
FloLey wants to merge 17 commits into master from claude/add-connect4-training-1zeh5

Conversation


@FloLey FloLey commented Mar 26, 2026

claude and others added 12 commits March 26, 2026 11:02
Implements a high-performance negamax solver with alpha-beta pruning
and transposition table for exact game-theoretic scoring of Connect4
positions. Includes systematic and random position generators that
output normalized CSV datasets for LLM fine-tuning.

https://claude.ai/code/session_01Y8peoctJpu7uW3znbX7WMM
The previous direction_wins helper was finding existing 4-in-a-rows
instead of the empty gap cells where a piece would complete one.
The correct approach enumerates all 4 gap patterns (XXX_, _XXX, XX_X,
X_XX) per direction. Vertical only needs one pattern due to gravity.

Also consolidates winning position logic into Board, removing the
duplicate implementation from Solver. ~30x speedup on typical positions.
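The gap-pattern enumeration described above can be sketched with shift arithmetic on a Pons-style bitboard. This is a Python sketch only (the actual solver is Rust); the layout assumption is bit index = col * 7 + row, with row 6 a sentinel, so the direction shifts are 1 (vertical), 7 (horizontal), and 6/8 (diagonals).

```python
# Assumed bitboard layout: bit index = col * 7 + row, row 6 is a sentinel.
BOTTOM = sum(1 << (7 * c) for c in range(7))   # bit 0 of every column
BOARD_MASK = BOTTOM * ((1 << 6) - 1)           # the 42 playable cells

def winning_cells(pos: int, mask: int) -> int:
    """Empty cells where `pos` (current player's stones) completes four.

    Vertical needs only one pattern (gravity stacks pieces); the other
    three directions enumerate all four gap patterns per direction.
    """
    r = (pos << 1) & (pos << 2) & (pos << 3)   # vertical: cell above XXX
    for d in (6, 7, 8):                        # diagonal, horizontal, diagonal
        p = (pos << d) & (pos << 2 * d)        # bit b set iff stones at b-d, b-2d
        r |= p & (pos << 3 * d)                # XXX_ : three stones, then the gap
        r |= p & (pos >> d)                    # XX_X : gap between pairs
        p = (pos >> d) & (pos >> 2 * d)        # bit b set iff stones at b+d, b+2d
        r |= p & (pos << d)                    # X_XX
        r |= p & (pos >> 3 * d)                # _XXX : gap, then three stones
    return r & BOARD_MASK & ~mask              # restrict to empty board cells
```

Masking with `BOARD_MASK & ~mask` at the end is what keeps sentinel-row phantom bits and occupied cells out of the result.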

https://claude.ai/code/session_01Y8peoctJpu7uW3znbX7WMM
Four bugs fixed:

1. Transposition table now stores bound type (Exact/Lower/Upper)
   instead of treating all values the same. The old code could use
   an upper bound for beta cutoff or raise alpha with a lower bound
   incorrectly, producing wrong scores.

2. Node count now accumulates across all column solves in rank_moves
   instead of being reset by each solve() call (was showing "1").

3. Terminal loss score formula fixed: -(43-(moves-1))/2 instead of
   -(43-moves)/2 to match the win detection formula. This off-by-one
   caused self-consistency failures on terminal positions.

4. CLI now detects and correctly reports game-over positions instead
   of silently showing "DRAW" for won games.
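Bug 1 above is worth a sketch: a transposition-table value may only be used in the way its bound type allows. A minimal Python illustration (names and entry shape are illustrative, not the Rust solver's):

```python
from enum import IntEnum

class Bound(IntEnum):
    EXACT = 0
    LOWER = 1   # true score >= value (search failed high)
    UPPER = 2   # true score <= value (search failed low)

def tt_probe(entry, alpha, beta):
    """Return (hit, score, alpha, beta), using the entry only as its
    bound type permits. The bug fixed in the commit was treating every
    stored value as exact."""
    if entry is None:
        return False, 0, alpha, beta
    value, bound = entry
    if bound == Bound.EXACT:
        return True, value, alpha, beta
    if bound == Bound.LOWER and value >= beta:
        return True, value, alpha, beta   # a lower bound may only cut on beta
    if bound == Bound.UPPER and value <= alpha:
        return True, value, alpha, beta   # an upper bound may only cut on alpha
    if bound == Bound.LOWER:
        alpha = max(alpha, value)         # otherwise: tighten the window
    elif bound == Bound.UPPER:
        beta = min(beta, value)
    return False, 0, alpha, beta
```

Using an upper bound to raise alpha (or a lower bound to lower beta) is exactly the inversion the old code allowed, and it silently corrupts scores.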

https://claude.ai/code/session_01Y8peoctJpu7uW3znbX7WMM
…ering

Major solver optimizations:
- Iterative deepening with null-window search: binary-searches for exact
  score using O(log(score_range)) null-window calls instead of one wide
  window. Each call maximizes alpha-beta cutoffs.
- Non-losing moves filter: bitwise exclusion of moves that play directly
  below an opponent winning cell, pruning 30-50% of the tree.
- Dynamic move ordering: scores each candidate move by the number of
  winning positions it creates, explores best moves first.
- TT increased to 24M entries for better cache behavior.
- Lower/upper bound initialization from TT in solve() for tighter
  initial window.

Performance: 15-move positions solve in ~250ms (was 290ms). Positions
with 8+ moves are fast enough for training data generation. Early game
(<6 moves) still needs more advanced techniques (anticipation).

Also fixes: bench command uses ROWS constant instead of hardcoded 6.
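The iterative-deepening driver above can be sketched as a binary search over the score range using null windows, patterned after Pons' solver. A Python sketch; `negamax(pos, alpha, beta)` stands in for the real fail-soft alpha-beta search:

```python
def solve(position, negamax, min_score, max_score):
    """Binary-search the exact score with null-window probes.

    Each probe `negamax(position, mid, mid + 1)` only answers "is the
    score > mid?", which maximizes alpha-beta cutoffs; O(log range)
    probes pin down the exact value.
    """
    lo, hi = min_score, max_score
    while lo < hi:
        mid = lo + (hi - lo) // 2
        # nudge mid toward zero first: shallow wins/losses resolve fastest
        if mid <= 0 and lo // 2 < mid:
            mid = lo // 2
        elif mid >= 0 and hi // 2 > mid:
            mid = hi // 2
        r = negamax(position, mid, mid + 1)   # null window probe
        if r <= mid:
            hi = r    # fail-low: score <= r
        else:
            lo = r    # fail-high: score >= r
    return lo
```

With a fail-soft search, each probe returns a bound that shrinks the remaining interval, so the loop always terminates at the exact score.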

https://claude.ai/code/session_01Y8peoctJpu7uW3znbX7WMM
Generator improvements:
- Rayon parallelism: positions are generated sequentially (cheap) then
  solved in parallel using thread-local Solvers. Each rayon thread gets
  its own TT for maximum cache efficiency.
- Cross-batch dedup: shared HashSet persists across batch boundaries,
  preventing duplicate positions between CSV files.
- Cross-phase dedup: generate_systematic returns its seen set, which is
  passed to the random phase in generate_full.
- Min solve depth: systematic generator skips solving positions with
  fewer than 8 moves (too expensive), while still enumerating them for
  DFS traversal to reach deeper positions.
- generate_systematic/generate_random no longer take &mut Solver
  (each thread creates its own via thread_local).
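The rayon `thread_local` pattern above has a direct analog in Python's `threading.local`: each worker lazily builds its own solver (and transposition table) on first use, so no cache state is ever shared. A sketch with hypothetical names, not the Rust implementation:

```python
import threading
from concurrent.futures import ThreadPoolExecutor

_tls = threading.local()

def make_solver():
    # Hypothetical stand-in for the Rust Solver: each worker gets its
    # own transposition table for maximum cache efficiency.
    return {"tt": {}, "owner": threading.get_ident()}

def get_solver():
    if not hasattr(_tls, "solver"):
        _tls.solver = make_solver()   # built once per worker thread
    return _tls.solver

def solve_position(moves):
    solver = get_solver()
    # ... the real search would run here; we just record who solved it
    return (moves, solver["owner"])

# positions are generated sequentially (cheap), solved in parallel
positions = [f"pos{i}" for i in range(8)]
with ThreadPoolExecutor(max_workers=4) as ex:
    results = list(ex.map(solve_position, positions))
```

The cross-batch dedup described above is then just a single `set` owned by the sequential generator, checked before a position is ever submitted to the pool.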

https://claude.ai/code/session_01Y8peoctJpu7uW3znbX7WMM
Four key optimizations:

1. Anticipation: after computing non-losing moves, check if each candidate
   allows the opponent to create a double threat (2+ winning positions on
   playable cells with a single move). Prune such moves before searching.

2. TT best-move ordering: store the best move (column) in TT entries and
   use it to explore the most promising move first. The TT move from
   previous null-window iterations is almost always the best move, giving
   much earlier alpha-beta cutoffs.

3. Packed TT entries: value (6 bits) + bound (2 bits) + best_move (3 bits)
   packed into u16, reducing per-entry memory from 18 bytes to 10 bytes.

4. O(1) legal_moves_mask via bottom-mask trick: (mask + BOTTOM_MASK) &
   BOARD_MASK replaces the 7-iteration column loop.

Combined effect: ~21% node reduction on depth-6 positions (44M → 35M).
Early game (<4 moves) remains slow — would need deeper anticipation or
opening book to match Pons' reference solver.
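Optimization 4, the bottom-mask trick, is compact enough to show directly. A Python sketch under the same layout assumption as before (bit = col * 7 + row, row 6 a sentinel):

```python
H = 7                                              # bits per column: 6 rows + sentinel
BOTTOM_MASK = sum(1 << (H * c) for c in range(7))  # bit 0 of every column
BOARD_MASK = BOTTOM_MASK * ((1 << 6) - 1)          # the 42 playable cells

def legal_moves_mask(mask: int) -> int:
    """One playable cell per non-full column, no loop over columns.

    Adding BOTTOM_MASK carries through each column's filled run of
    bits, setting the first empty cell; a full column carries into its
    sentinel bit, which BOARD_MASK strips.
    """
    return (mask + BOTTOM_MASK) & BOARD_MASK
```

This replaces a 7-iteration column scan with one add and one mask.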

https://claude.ai/code/session_01Y8peoctJpu7uW3znbX7WMM
Critical fix + major performance improvement based on Pons' reference:

1. Fix: non_losing_moves() now masks compute_winning_positions output
   with (BOARD_MASK & !mask) to restrict to empty board cells. Previously
   phantom bits in sentinel rows leaked through the >> 1 shift, causing
   valid moves to be incorrectly excluded. Also moved double-threat
   detection (forced & (forced-1)) into non_losing_moves itself.
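The `forced & (forced - 1)` idiom clears the lowest set bit, so a nonzero result means the opponent threatens two or more immediate wins at once. A Python sketch of the resulting move filter (simplified; cell-above is `<< 1` in the assumed column layout):

```python
def non_losing_moves(possible: int, opponent_wins: int) -> int:
    """Moves that don't lose immediately.

    `possible`: playable cells this turn; `opponent_wins`: empty cells
    where the opponent would complete four (already masked to empty
    board cells, per the fix above).
    """
    forced = possible & opponent_wins
    if forced:
        if forced & (forced - 1):   # >= 2 forced blocks: position is lost
            return 0
        possible = forced           # must block the single threat
    # never play directly below an opponent winning cell
    return possible & ~(opponent_wins >> 1)
```

Returning 0 from a double threat is what lets the search cut off without exploring any child.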

2. Compact TT matching Pons' design:
   - u32 partial keys + u8 values = 5 bytes/entry (was 10 bytes)
   - Pons encoding: upper bounds as alpha-MIN+1 (1..37), lower bounds
     as score+MAX-2*MIN+2 (38..74), 0 = empty. No separate bound flag.
   - 8.4M entries = 42MB (fits near 33MB L3 cache)
   - Result: 56% faster node throughput (6.7M vs 4.3M nodes/sec)

3. Removed can_create_double_threat anticipation from negamax (Pons
   doesn't use it — overhead exceeded benefit).

4. Matched Pons' column order: {3,4,2,5,1,6,0}.

Net: depth-6 position 9.9s → 5.1s (48% faster).
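The Pons value encoding from item 2 packs the bound type into the value's range instead of a separate flag. A Python sketch; the score bounds are the conventional Connect Four limits (an assumption matching the 1..37 / 38..74 ranges above):

```python
MIN_SCORE = -18   # worst score: loss on the last stone
MAX_SCORE = 18    # best score: win as early as possible

def encode_upper(value: int) -> int:
    """Upper bounds stored as 1..37 (0 means an empty slot)."""
    return value - MIN_SCORE + 1

def encode_lower(value: int) -> int:
    """Lower bounds stored as 38..74 -- disjoint from upper bounds."""
    return value + MAX_SCORE - 2 * MIN_SCORE + 2

def decode(stored: int):
    if stored == 0:
        return None                              # empty TT slot
    if stored <= MAX_SCORE - MIN_SCORE + 1:
        return ("upper", stored + MIN_SCORE - 1)
    return ("lower", stored - MAX_SCORE + 2 * MIN_SCORE - 2)
```

Because both bound kinds fit in 1..74, the value is a single `u8` and no bound flag is stored at all, which is what shrinks entries to 5 bytes.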

https://claude.ai/code/session_01Y8peoctJpu7uW3znbX7WMM
Exposes min_solve_depth as a CLI parameter so users can control which
depths get solved. Default is 0 (solve all positions, including early
game). On slower machines, use --min-depth 8 to skip expensive early
positions.

Usage:
  generate systematic --max-depth 10 --min-depth 0   # solve everything
  generate full --min-depth 8                         # skip depths 0-7
  generate full                                       # default: solve all

https://claude.ai/code/session_01Y8peoctJpu7uW3znbX7WMM
1. Cargo.toml: use caret requirements ("1" instead of "1.3"/"1.10")
   for csv and rayon dependencies.

2. board.rs: undo() no longer takes a col parameter — uses
   move_history.pop() to determine which column, preventing
   mismatched undo calls.

3. board.rs: test_is_full now tests both empty and full boards.

4. board.rs: test_non_losing_moves now verifies specific forced
   blocking behavior (col 2 must be included when P2 threatens
   horizontal completion there).

5. main.rs: terminal position message clarified — "current player
   LOST" instead of ambiguous "player who just moved WON" + negative
   score.
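Change 2, undo-from-history, can be sketched as an interface. This is a Python stand-in for the Rust bitboard `Board`, just to show the design:

```python
class Board:
    """Minimal sketch of the undo-from-history design."""

    def __init__(self):
        self.columns = [[] for _ in range(7)]
        self.move_history = []
        self.player = 1

    def play(self, col: int):
        self.columns[col].append(self.player)
        self.move_history.append(col)
        self.player = 3 - self.player

    def undo(self):
        # No `col` parameter: the history decides which column, so an
        # undo can never be mismatched with the move it reverses.
        col = self.move_history.pop()
        self.columns[col].pop()
        self.player = 3 - self.player
```

Taking the column from `move_history.pop()` makes a mismatched undo unrepresentable rather than merely incorrect.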

https://claude.ai/code/session_01Y8peoctJpu7uW3znbX7WMM
Rewrite solver in Rust with bitboard engine and training data generator

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a training pipeline for Connect Four LLM agents using SFT and GRPO via the Unsloth library, including a comprehensive RunPod setup guide and evaluation scripts. The review feedback identifies critical improvements regarding data leakage and reproducibility in the dataset splitting logic, as well as potential runtime errors due to division by zero in the evaluation statistics. Additionally, recommendations were made to refactor global state into a more encapsulated class structure and to clarify command-line examples in the documentation for better usability.

Comment thread training/connect4_train.py Outdated
if not checkpoint_dir: print("ERROR: No model found"); return
model, tokenizer = FastLanguageModel.from_pretrained(model_name=checkpoint_dir, max_seq_length=config["max_seq_length"], load_in_4bit=config["load_in_4bit"])
FastLanguageModel.for_inference(model)
raw_data = load_csv_data(config["csv_path"]); random.seed(42); random.shuffle(raw_data); eval_data = raw_data[-10_000:]

high

There are a couple of issues with how the training and evaluation datasets are split, which can lead to data leakage and lack of reproducibility:

  1. Data Leakage: The training data is taken from the start of the shuffled list (data[:max_rows]) in prepare_sft_dataset and prepare_grpo_dataset, while the evaluation data is taken from the end (raw_data[-10_000:]). If sft_max_rows or grpo_max_rows is large enough to overlap with the last 10,000 samples, the model will be evaluated on data it has already seen during training.
  2. Reproducibility: random.shuffle(data) is called in prepare_sft_dataset and prepare_grpo_dataset without setting a random seed. This means the training set will be different on each run, making experiments difficult to reproduce. While run_eval sets a seed, the training data preparation does not.

A better approach is to perform a single, deterministic train/test split at the beginning of the script or in main.

Here's a suggested refactoring:

# In main()
raw_data = load_csv_data(config["csv_path"])
random.seed(42)
random.shuffle(raw_data)
eval_data = raw_data[-10_000:]
train_data = raw_data[:-10_000]

# Then pass `train_data` to run_sft/run_grpo and `eval_data` to run_eval
# and remove data loading and splitting from within those functions.
# Also remove `random.shuffle` from the `prepare_*_dataset` functions.

Comment thread training/connect4_train.py Outdated
Comment on lines +190 to +197
print(f" Valid: {valid}/{total} ({100*valid/total:.1f}%)")
print(f" Exact match: {exact}/{valid} ({100*exact/valid:.1f}%)")
print(f" Top-2 match: {top2}/{valid} ({100*top2/valid:.1f}%)")
print(f" Mean oracle score: {score_sum/valid:+.4f}")
for phase, s in phase_stats.items():
if s["total"]>0: print(f" {phase}: {100*s['correct']/s['total']:.1f}% exact, {s['score_sum']/s['total']:+.3f} avg (n={s['total']})")
results = {"model":config["model_name"],"model_size":config["model_size"],"exact_match_pct":round(100*exact/valid,2),"top2_match_pct":round(100*top2/valid,2),"mean_oracle_score":round(score_sum/valid,4)}
with open(f"eval_results_{config['model_size']}.json","w") as f: json.dump(results,f,indent=2)

high

If total or valid is 0, this block of code will raise a ZeroDivisionError when calculating and printing statistics. You should add checks to handle these cases before performing division.
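A guarded version of the statistics block might look like the following (function and parameter names are illustrative, not from the PR):

```python
def print_eval_stats(valid, total, exact, top2, score_sum):
    """Print evaluation statistics, checking every divisor first."""
    if total == 0:
        print("No samples evaluated.")
        return
    print(f" Valid: {valid}/{total} ({100*valid/total:.1f}%)")
    if valid == 0:
        print(" No valid completions; skipping accuracy metrics.")
        return
    print(f" Exact match: {exact}/{valid} ({100*exact/valid:.1f}%)")
    print(f" Top-2 match: {top2}/{valid} ({100*top2/valid:.1f}%)")
    print(f" Mean oracle score: {score_sum/valid:+.4f}")
```

Early returns keep the happy path unindented while making both zero cases explicit.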

Comment thread training/RUNPOD_SETUP.md Outdated
Comment on lines +122 to +123
python connect4_train.py --model {4b,8b,14b} --stage export --csv connect4_data.csv
```

medium

The placeholder {4b,8b,14b} in the command could be slightly confusing for users. It's clearer to provide a concrete example to show how to replace the placeholder.

Suggested change
- python connect4_train.py --model {4b,8b,14b} --stage export --csv connect4_data.csv
+ python connect4_train.py --model 8b --stage export --csv connect4_data.csv

Comment thread training/connect4_train.py Outdated
formatted.append({"prompt": tokenizer.apply_chat_template(conv, tokenize=False, add_generation_prompt=True), "move_sequence": entry["move_sequence"]})
return Dataset.from_list(formatted)

SCORE_LOOKUP = {}

medium

Using a global variable SCORE_LOOKUP that is modified in run_grpo and read by the reward functions makes the code harder to reason about and test. This creates a hidden dependency between these functions.

A better approach would be to encapsulate this state. For example, you could use a class to hold the lookup table and define the reward functions as its methods. This would make the data flow explicit.

Example refactoring:

class RewardCalculator:
    def __init__(self, data):
        self.score_lookup = build_lookup_table(data)

    def reward_move_quality(self, completions, move_sequence, **kwargs):
        # ... use self.score_lookup
        ...

# In run_grpo:
raw_data = load_csv_data(config["csv_path"])
reward_calculator = RewardCalculator(raw_data)
trainer = GRPOTrainer(
    ...
    reward_funcs=[
        reward_calculator.reward_format, 
        reward_calculator.reward_move_quality, 
        reward_calculator.reward_is_best_move
    ],
    ...
)

claude added 2 commits March 26, 2026 21:21
- New --stage push uploads merged 16-bit + GGUF to HF Hub
- --hf-repo flag to specify target repo (creates -GGUF sibling repo)
- Updated RunPod guide with Step 6 and huggingface-cli login

https://claude.ai/code/session_015dRfJUv1tpJrJ4AVgqLaoL
…wards

- Single deterministic train/eval split in main() prevents data leakage
- Remove random.shuffle from prepare_*_dataset (split is done upfront)
- Add ZeroDivisionError guards for total==0 and valid==0 in eval
- Replace global SCORE_LOOKUP with RewardCalculator class
- Use concrete example in RUNPOD_SETUP.md export command

https://claude.ai/code/session_015dRfJUv1tpJrJ4AVgqLaoL

FloLey commented Mar 26, 2026

/gemini review


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a Connect Four LLM training pipeline utilizing SFT and GRPO with the Unsloth library, complemented by RunPod setup documentation and environment configuration. The reviewer suggested several improvements, including exposing more training parameters as command-line arguments, adhering to PEP 8 standards for imports and assignments, removing unused imports, and avoiding the use of global variables to improve code maintainability.

Comment thread training/RUNPOD_SETUP.md Outdated

- Lower temperature (try 0.5 instead of 0.7) — the model may be generating too randomly.
- Increase `grpo_grad_accum` to 8 or 16 for more stable gradients.
- Try switching to `loss_type="dr_grpo"` in the GRPOConfig (edit the script).

medium

The troubleshooting tip suggests editing the script for loss_type="dr_grpo". It would be more user-friendly and consistent with other configurable parameters if this could also be passed as a command-line argument to connect4_train.py.

Comment thread training/connect4_train.py Outdated
python connect4_train.py --model 8b --stage push --hf-repo yourname/connect4-agent-8b
"""

import argparse, csv, json, os, re, random

medium

For better readability and adherence to PEP 8, it is recommended to have one import statement per line.

import argparse
import csv
import json
import os
import re
import random

Comment thread training/connect4_train.py Outdated
"""

import argparse, csv, json, os, re, random
from collections import defaultdict

medium

The defaultdict import is not used in this script. Please remove unused imports to keep the codebase clean.

Suggested change
- from collections import defaultdict
  import torch

Comment thread training/connect4_train.py Outdated
parser.add_argument("--hf-repo", default=None, help="HuggingFace repo id for push (e.g. yourname/connect4-agent-8b)")
parser.add_argument("--no-wandb", action="store_true")
args = parser.parse_args()
if args.no_wandb: global WANDB_AVAILABLE; WANDB_AVAILABLE = False

medium

Modifying a global variable (WANDB_AVAILABLE) directly within a function is generally discouraged as it can lead to unexpected side effects and make code harder to reason about. Consider passing WANDB_AVAILABLE as an argument to main or encapsulating this logic within a class.

Suggested change
- if args.no_wandb: global WANDB_AVAILABLE; WANDB_AVAILABLE = False
+ if args.no_wandb:
+     global WANDB_AVAILABLE
+     WANDB_AVAILABLE = False

Comment thread training/connect4_train.py Outdated
parser.add_argument("--no-wandb", action="store_true")
args = parser.parse_args()
if args.no_wandb: global WANDB_AVAILABLE; WANDB_AVAILABLE = False
config = get_config(args.model); config["csv_path"] = args.csv; config["hf_repo"] = args.hf_repo

medium

Multiple assignments on a single line can sometimes reduce readability. For better clarity and adherence to PEP 8, consider splitting this into separate lines.

Suggested change
- config = get_config(args.model); config["csv_path"] = args.csv; config["hf_repo"] = args.hf_repo
+ config = get_config(args.model)
+ config["csv_path"] = args.csv
+ config["hf_repo"] = args.hf_repo

claude added 2 commits March 26, 2026 22:18
- PEP 8: one import per line, remove unused defaultdict import
- Move global WANDB_AVAILABLE mutation to separate lines
- Split chained assignments in main() for readability
- Add --loss-type CLI flag for GRPO (e.g. --loss-type dr_grpo)
- Update RUNPOD_SETUP.md troubleshooting to reference CLI flag

https://claude.ai/code/session_015dRfJUv1tpJrJ4AVgqLaoL
…inking

- Reconstruct board from move_sequence (mirrors game.py exactly)
- Use production SYSTEM_PROMPT and USER_TEMPLATE from ai.py
- Drop SFT stage — Qwen3 already knows JSON and reasons natively
- GRPO-only with extended thinking: max_completion_length=3072
- JSON-aware output parsing (handles <think> blocks)
- Tiered format reward: +0.5 JSON, +0.25 regex, -1.0 invalid
- Eval now reports JSON format % alongside accuracy metrics
- Reduced num_generations (4/3/3) to fit longer completions in VRAM
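The tiered format reward above can be sketched as a single function. Everything beyond the +0.5 / +0.25 / -1.0 tiers stated in the commit is an assumption here: the `<think>`-stripping regex, the `move` key name, and the digit fallback.

```python
import json
import re

def format_reward(completion: str) -> float:
    """Tiered format reward: +0.5 for clean JSON with a valid `move`,
    +0.25 if only a regex can pull a column out, -1.0 otherwise."""
    # drop any extended-thinking block before parsing
    answer = re.sub(r"<think>.*?</think>", "", completion, flags=re.DOTALL).strip()
    try:
        obj = json.loads(answer)
        if isinstance(obj, dict) and obj.get("move") in range(7):
            return 0.5
    except json.JSONDecodeError:
        pass
    if re.search(r"\b[0-6]\b", answer):   # fallback: a bare column number
        return 0.25
    return -1.0
```

Grading format separately from move quality lets the model earn partial credit for parseable output even while its play is still poor.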

https://claude.ai/code/session_015dRfJUv1tpJrJ4AVgqLaoL

FloLey commented Mar 28, 2026

/Gemini review


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a comprehensive training and evaluation pipeline for Connect Four agents using GRPO and Unsloth, accompanied by a detailed RunPod setup guide and updated ignore rules. The implementation covers board reconstruction, reward calculation, and model export functionality. The review feedback provides several actionable improvements to enhance the robustness and maintainability of the code, including better CSV file handling, ensuring deterministic evaluation through greedy decoding, preventing redundant repository naming suffixes during HuggingFace uploads, and refactoring global state management to avoid side effects.

Comment thread training/connect4_train.py Outdated

def load_csv_data(csv_path):
data = []
with open(csv_path, "r") as f:

medium

For improved robustness and to prevent platform-specific issues, it's best practice to specify encoding="utf-8" and newline="" when opening CSV files with Python's csv module. This ensures consistent behavior across different operating systems and correctly handles newlines within quoted fields.

Suggested change
- with open(csv_path, "r") as f:
+ with open(csv_path, "r", encoding="utf-8", newline="") as f:

Comment thread training/connect4_train.py Outdated
).to(model.device)

with torch.no_grad():
outputs = model.generate(input_ids=inputs, max_new_tokens=3072, temperature=0.01, do_sample=True)

medium

The evaluation process is currently stochastic due to do_sample=True. For reproducible and deterministic evaluation results, it's recommended to use greedy decoding. This can be achieved by setting do_sample=False. The temperature parameter can then be removed as it has no effect with greedy decoding.

Suggested change
- outputs = model.generate(input_ids=inputs, max_new_tokens=3072, temperature=0.01, do_sample=True)
+ outputs = model.generate(input_ids=inputs, max_new_tokens=3072, do_sample=False)

print("ERROR: huggingface_hub not installed. Run: pip install huggingface_hub")
return
from huggingface_hub import HfApi, create_repo
hf_repo = config.get("hf_repo")

medium

The push_to_hub function constructs the GGUF repository name by appending -GGUF to the provided --hf-repo argument. If a user mistakenly provides a repo name that already ends with -GGUF, it will result in a doubly-appended name (e.g., my-repo-GGUF-GGUF). It would be more robust to add a check for this.

You could add a check after getting hf_repo from the config:

    hf_repo = config.get("hf_repo")
    if hf_repo and hf_repo.upper().endswith("-GGUF"):
        print("WARNING: --hf-repo should be the base repo name. Stripping '-GGUF' suffix.")
        hf_repo = hf_repo[:-5]

Comment thread training/connect4_train.py Outdated
Comment on lines +567 to +569
if args.no_wandb:
global WANDB_AVAILABLE
WANDB_AVAILABLE = False

medium

Modifying a global variable (WANDB_AVAILABLE) within the main function can make the code harder to understand and maintain. A cleaner approach is to determine the desired wandb state within main and pass it as an argument to the functions that need it. This avoids side effects and makes data flow explicit.

Consider this alternative approach:

  1. In main(), remove the global modification and instead create a local variable and pass it via the config dictionary:

    # In main()
    use_wandb = WANDB_AVAILABLE and not args.no_wandb
    config['use_wandb'] = use_wandb
  2. Update run_grpo to use this config value:

    # In run_grpo()
    use_wandb = config.get('use_wandb', WANDB_AVAILABLE) # Fallback to global default
    if use_wandb:
        wandb.init(...)
    
    grpo_kwargs = dict(
        ...
        report_to="wandb" if use_wandb else "none",
        ...
    )

This makes the dependency on wandb explicit.

- Add encoding="utf-8" and newline="" to CSV open for cross-platform safety
- Use do_sample=False in eval for deterministic, reproducible results
- Guard against --hf-repo ending in -GGUF (strip suffix with warning)
- Pass wandb state via config["use_wandb"] instead of mutating global

https://claude.ai/code/session_015dRfJUv1tpJrJ4AVgqLaoL