diff --git a/DEVLOG.md b/DEVLOG.md index 0b47483..e84c51d 100644 --- a/DEVLOG.md +++ b/DEVLOG.md @@ -213,7 +213,7 @@ Reading it honestly: QA-only heuristic cannot cover header->question and other link types. `all_199` carries the "contains the 50 test + 149 tuned forms, not held-out" caveat in the report JSON. -Design (locked in discussion; see `plans/phase3-funsd-relations.md`): +Design (locked in discussion; see `docs/phase3_brief.md`): - **Predictor:** per-answer argmax + distance gate; distances normalized by the form's median entity height; two separate knobs (`max_distance_units` distance gate, `min_score` floor). - **GT links:** deduped to undirected frozensets (FUNSD records links bidirectionally), then diff --git a/docs/phase3_brief.md b/docs/phase3_brief.md new file mode 100644 index 0000000..300d415 --- /dev/null +++ b/docs/phase3_brief.md @@ -0,0 +1,141 @@ +# Phase 3 — FUNSD relation branch (V1 baseline) + +> Implementation brief for Phase 3. Committed in the repo (travels with `git pull` to Colab) +> so the references to it in `DEVLOG.md` and the `src/funsd_extraction.py` / +> `src/eval_funsd.py` module docstrings resolve. Status: V1 implemented and scored on +> `feature/phase3-funsd-relations`; headline held-out `test_50.qa_links` F1 0.727. + +## Context + +Phases 0-2 are merged to `main` (FinTabNet.c table topology + OCR content + table-only RAG + +DocLayNet layout crop). Phase 3 is a **FUNSD relation-linking baseline** over GT entities. + +It is a deliberately standalone branch — it does **not** touch the RAG pipeline and only +wires into the demo in Phase 4. It is **annotation-only and CPU-only**: the FUNSD +annotation JSON already carries each entity's text, bbox, label, and GT `linking` pairs, +so scoring, the spatial heuristic, and the gold links never load image pixels. No GPU, no +Colab — this is the local "logic loop", fast `pytest`. 
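A minimal sketch of the annotation-only point above, assuming the public FUNSD annotation layout (a top-level `form` list whose items carry `id`, `text`, `box`, `label`, and `linking`). It shows that gold links can be read and deduplicated to undirected pairs without opening any image. The helper name `dedupe_links` and the inline `sample` form are hypothetical illustrations, not the repo's `parse_funsd_form`.

```python
from __future__ import annotations


def dedupe_links(data: dict) -> set[frozenset[int]]:
    """Collapse FUNSD `linking` pairs into undirected links (illustrative only)."""
    ids = {ent["id"] for ent in data["form"]}
    links: set[frozenset[int]] = set()
    for ent in data["form"]:
        for a, b in ent.get("linking", []):
            # Skip self-links and links that point at ids missing from the form.
            if a == b or a not in ids or b not in ids:
                continue
            # A frozenset makes (0, 1) and (1, 0) the same link, so the
            # bidirectional duplicates collapse automatically inside the set.
            links.add(frozenset((a, b)))
    return links


# Hypothetical two-entity form: the same link is recorded on both endpoints,
# plus one self-link and one link to an id that does not exist.
sample = {
    "form": [
        {"id": 0, "label": "question", "text": "Date:",
         "box": [10, 10, 60, 24], "linking": [[0, 1]]},
        {"id": 1, "label": "answer", "text": "1/2/98",
         "box": [70, 10, 130, 24], "linking": [[0, 1], [1, 1], [1, 7]]},
    ]
}

links = dedupe_links(sample)
print(len(links))
```

Only the annotation dict is touched here, which is why this part of the phase can stay CPU-only and image-free.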
+ +The task: given GT entities `{question, answer, header, other}`, predict which entities are +linked, scored P/R/F1 against GT `linking`. + +## Locked design decisions + +- **Predictor:** deterministic spatial heuristic, **per-answer argmax + distance gate**. + Each `answer` picks its single highest-scoring `question`; emit the link only if the best + score clears the gate. Distances normalized by the form's **median entity height** so one + gate works across differently-scaled scans. (Rejected: per-question argmax -> under-predicts + multi-answer questions; global threshold -> precision-fragile, needs a tuned cutoff.) +- **GT links:** dedupe `linking` to **undirected** `frozenset({id_a, id_b})` per form (FUNSD + records links bidirectionally/duplicated). Then derive scopes: + - `qa_links` (**primary**): pairs whose endpoints are one `question` + one `answer`, + canonicalized to directed `(question_id, answer_id)`. + - `all_links` (**secondary**): the full deduped undirected set, scored as frozensets. +- **Eval split / reporting matrix:** + - **Primary headline:** `test_50.qa_links.micro_f1` (official 50-form test split). + - Set/tune any heuristic params on **train_149 only**; never on the reported test set. + - **Secondary:** `all_199.qa_links`, `test_50.all_links`, `all_199.all_links`. Print + `all_199` with a "contains the 50 test + 149 tuned forms, not held-out" caveat. + - `debug_20` (first 20 train forms) is for parser/CLI smoke only, never for tuning. +- **all_links is a coverage diagnostic, not a second predictor.** V1 predicts only `q->a`; + the `all_links` row scores those same QA predictions (as undirected frozensets) against the + full GT link set, i.e. "what fraction of all GT links does the QA-only heuristic cover." +- **No sklearn in V1.** P/R/F1 is set arithmetic (`len(pred & gold)/len(pred)` etc.). sklearn + enters only if a fitted ranker is added later, and only fit on train_149. +- **Data:** raw FUNSD zip -> `data/raw/funsd/...` (gitignored). 
Tests use **synthetic fixtures + only**, never the raw dataset. + +## Files + +### `src/config.py` — FUNSD paths +```python +FUNSD_ROOT = DATA_ROOT / "raw" / "funsd" / "dataset" +FUNSD_TRAIN = FUNSD_ROOT / "training_data" / "annotations" # 149 forms +FUNSD_TEST = FUNSD_ROOT / "testing_data" / "annotations" # 50 forms +``` +Output reuses existing `config.EVALUATION`. + +### `src/funsd_extraction.py` — data contract + baseline predictor +Follows the `from __future__ import annotations` + TypedDict + pure-function style of +`src/canonical_schema.py` / `src/eval_retrieval.py`. +- `FunsdEntity` TypedDict: `id, label, text, box [x0,y0,x1,y1]`. +- `FunsdForm` TypedDict: `form_id, entities, gold_links (set[frozenset[int]])`. In-memory it + is a `set` (dedupe is its nature); cast to `list` only when serializing JSON output. +- `parse_funsd_form(data, form_id)` / `parse_funsd_json(path)`: normalize entities, collect + `linking`, dedupe to undirected frozensets, drop self-links and links to missing ids. +- `load_funsd_split(annotations_dir)`. +- `qa_gold_links(form) -> set[tuple[int,int]]`: question+answer pairs -> directed `(q,a)`. +- `all_gold_links(form) -> set[frozenset]`: the deduped undirected set. +- `HeuristicParams` (frozen dataclass, a-priori defaults; the documented tunable surface, + train_149 only). Two clearly-separated knobs, not one fuzzy "gate": + - `max_distance_units`: **distance gate** — a (Q, A) candidate is rejected if its + median-height-normalized distance exceeds this. Filters the candidate set. + - `min_score`: **score threshold** — the per-answer argmax winner is emitted only if its + final score clears this. Acceptance test on the chosen link. + - plus right-band tolerance, below-gap tolerance, and the boost weights. +- `predict_qa_links(form, params=HeuristicParams()) -> set[tuple[int,int]]`: per-answer argmax. 
+ For each answer A, score **every** question candidate (distances scaled by + `median_entity_height(entities)`), drop candidates beyond `max_distance_units`, take the + highest-scoring question, and emit `(q,a)` only if that score >= `min_score`. Geometry: + - **same-row right-side** (A vertically within Q's band, A to Q's right): strongest score. + - **below** (A under Q, horizontally aligned/overlapping): fallback score. + - proximity + alignment are additive boosts. (below and right-side compete in the same + argmax — a below candidate wins whenever it is the best valid candidate.) + +### `src/eval_funsd.py` — custom set-based metrics +Split so the metric stays pure (no predictor/params dependency) and the form-runner is separate: +- `prf1(pred, gold) -> dict`: one pred/gold set -> tp/precision/recall/f1, zero-guards. +- `evaluate_pairs(per_form) -> dict`: **pure** micro P/R/F1 over prebuilt (pred, gold) pairs + + counts. No predictor inside, so it is trivially unit-tested with synthetic sets. +- `evaluate_forms(forms, scope, params=HeuristicParams()) -> dict`: builds per-form (pred, gold) + for `scope in {"qa","all"}` — qa = directed `(q,a)` tuples from `predict_qa_links` vs + `qa_gold_links`; all = those predictions cast to undirected frozensets vs `all_gold_links` — + then delegates to `evaluate_pairs`. + +### `scripts/evaluate_funsd.py` — CLI runner +Mirrors `scripts/evaluate_rag.py`. Loads `train_149` / `test_50`, builds `all_199` and +`debug_20`, runs split x scope, writes `config.EVALUATION / "phase3_funsd_relations.json"`, +prints the headline `test_50.qa_links` + secondaries. Guards with a `SystemExit` pointing at +`fetch_funsd.py` when the dataset is missing. + +### `scripts/fetch_funsd.py` — one-time data helper +`urllib` + `zipfile` download/extract of the official FUNSD zip to `data/raw/funsd/`, with a +`--url` override and a printed manual-download fallback. Not used by tests; not on the gate. 
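The set-based metrics described for `src/eval_funsd.py` above can be sketched roughly as follows. This is an illustration under stated assumptions, not the module's actual code: the function names `prf1_sketch` and `micro_prf1_sketch` and the returned dict keys are hypothetical, and only the set arithmetic (`len(pred & gold) / len(pred)` and so on) comes from the brief.

```python
from __future__ import annotations

from typing import Iterable


def prf1_sketch(pred: set, gold: set) -> dict:
    """P/R/F1 for one pred/gold set pair, with zero-guards (hypothetical helper)."""
    tp = len(pred & gold)
    precision = tp / len(pred) if pred else 0.0
    recall = tp / len(gold) if gold else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return {"tp": tp, "precision": precision, "recall": recall, "f1": f1}


def micro_prf1_sketch(per_form: Iterable[tuple[set, set]]) -> dict:
    """Micro-average: pool tp / pred / gold counts over forms, then divide once."""
    tp = n_pred = n_gold = 0
    for pred, gold in per_form:
        tp += len(pred & gold)
        n_pred += len(pred)
        n_gold += len(gold)
    precision = tp / n_pred if n_pred else 0.0
    recall = tp / n_gold if n_gold else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return {"tp": tp, "n_pred": n_pred, "n_gold": n_gold,
            "precision": precision, "recall": recall, "f1": f1}


# Synthetic (pred, gold) pairs of directed (question_id, answer_id) tuples.
pairs = [
    ({(0, 1), (2, 3)}, {(0, 1), (4, 5), (6, 7)}),  # 1 hit, 1 spurious, 2 missed
    (set(), {(1, 2)}),                             # nothing predicted
]
micro = micro_prf1_sketch(pairs)
print(micro["precision"], micro["recall"])
```

Because the metric only sees prebuilt sets, the same function can score the `qa` scope (directed tuples) and the `all` scope (frozensets) as long as predictions and gold use the same element type.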
+ +### `tests/test_funsd_relations.py` — synthetic fixtures only (acceptance gate) +Inline tiny forms (dicts), no raw dataset. Covers: parse + gold_links; bidirectional/duplicate +dedupe; qa-link filter and direction canonicalization; all_links scope; `prf1` / `evaluate_pairs` +edge cases; same-row right-side link; below-candidate-wins-when-best; `other` never linked; +per-answer argmax picks the nearer; distance gate; header->question excluded from QA; and two +form-level `evaluate_forms` cases. + +### `notebooks/05_phase3_funsd_relations.ipynb` — Colab/local runner (no logic) +Mount/pull/fetch/test/evaluate/display + a read-only qualitative error table. + +## Out of scope (V1) +- FUNSD token classification (V2 / seqeval) — future work. +- Image/overlay loading — optional later debug aid, not in the baseline or the gate. +- Any RAG integration — Phase 4. +- sklearn / fitted rankers. + +## Result (real FUNSD, untuned a-priori params) + +Headline (held-out): `test_50.qa_links` P 0.946 / R 0.590 / **F1 0.727**. The params are untuned +a-priori defaults, so nothing was tuned on the test set; `train_149` F1 (0.665) is below test, +which is consistent with that. Recall is capped by design (single link per +answer + right-side/below geometry); threshold-based multi-link is the documented next lever. +Full split x scope matrix in `DEVLOG.md` (2026-06-03) and +`outputs/evaluation/phase3_funsd_relations.json`. + +## Verification +1. **Unit (the gate):** `pytest tests/test_funsd_relations.py` green — fully synthetic, local, + no GPU/network. Full suite 236 passed. +2. **End-to-end (needs the dataset, still local/CPU):** `python scripts/fetch_funsd.py` then + `python scripts/evaluate_funsd.py` -> writes `outputs/evaluation/phase3_funsd_relations.json`. + +## Branch / workflow +- This brief is committed at `docs/phase3_brief.md`. The `plans/` directory stays gitignored + for local scratch (harness plan file, PR body draft); the canonical brief lives here. +- Branch `feature/phase3-funsd-relations` cut from `origin/main`.
Entirely local phase — no + Colab round-trip needed. +- Build order (TDD): fixtures+tests -> `funsd_extraction.py` -> `eval_funsd.py` -> + `evaluate_funsd.py` -> `fetch_funsd.py` -> docs. diff --git a/notebooks/05_phase3_funsd_relations.ipynb b/notebooks/05_phase3_funsd_relations.ipynb index 4a10322..f8a4602 100644 --- a/notebooks/05_phase3_funsd_relations.ipynb +++ b/notebooks/05_phase3_funsd_relations.ipynb @@ -190,7 +190,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Step 4 - optional error peek" + "## Step 4a - Error table\n", + "\n", + "Notebook-only qualitative error analysis. This reads FUNSD JSON annotations and relation predictions; it does not write artifacts and does not affect the Phase 3 acceptance gate." ] }, { @@ -199,34 +201,185 @@ "metadata": {}, "outputs": [], "source": [ - "# Show the lowest-recall held-out forms. This is read-only and does not write artifacts.\n", + "# Build a held-out error table with missed/spurious QA links and entity text/bboxes.\n", "from src.funsd_extraction import load_funsd_split, predict_qa_links, qa_gold_links\n", "from src import config\n", "\n", + "def _ent_lookup(form):\n", + " return {e['id']: e for e in form['entities']}\n", + "\n", + "def _short(text, n=70):\n", + " text = ' '.join(str(text).split())\n", + " return text if len(text) <= n else text[:n - 1] + '...'\n", + "\n", + "def _link_text(pair, by_id):\n", + " qid, aid = pair\n", + " q = by_id[qid]\n", + " a = by_id[aid]\n", + " return f\"{qid}:{_short(q['text'])} -> {aid}:{_short(a['text'])}\"\n", + "\n", + "def _link_box(pair, by_id):\n", + " qid, aid = pair\n", + " return {\n", + " 'question_id': qid,\n", + " 'question_box': by_id[qid]['box'],\n", + " 'answer_id': aid,\n", + " 'answer_box': by_id[aid]['box'],\n", + " }\n", + "\n", "test_forms = load_funsd_split(config.FUNSD_TEST)\n", - "miss_rows = []\n", + "ERROR_FORMS = {}\n", + "ERROR_ROWS = []\n", + "ERROR_LINK_ROWS = []\n", + "\n", "for form in test_forms:\n", + " by_id = 
_ent_lookup(form)\n", " pred = predict_qa_links(form)\n", " gold = qa_gold_links(form)\n", - " tp = len(pred & gold)\n", - " recall = tp / len(gold) if gold else 0.0\n", + " correct = pred & gold\n", + " missed = gold - pred\n", + " spurious = pred - gold\n", + " tp = len(correct)\n", " precision = tp / len(pred) if pred else 0.0\n", - " miss_rows.append({\n", + " recall = tp / len(gold) if gold else 0.0\n", + " ERROR_FORMS[form['form_id']] = {\n", + " 'form': form,\n", + " 'pred': pred,\n", + " 'gold': gold,\n", + " 'correct': correct,\n", + " 'missed': missed,\n", + " 'spurious': spurious,\n", + " 'precision': precision,\n", + " 'recall': recall,\n", + " }\n", + " ERROR_ROWS.append({\n", " 'form_id': form['form_id'],\n", " 'precision': round(precision, 3),\n", " 'recall': round(recall, 3),\n", " 'tp': tp,\n", " 'pred': len(pred),\n", " 'gold': len(gold),\n", + " 'missed': len(missed),\n", + " 'spurious': len(spurious),\n", + " 'missed_links': '; '.join(_link_text(pair, by_id) for pair in sorted(missed))[:500],\n", + " 'spurious_links': '; '.join(_link_text(pair, by_id) for pair in sorted(spurious))[:500],\n", " })\n", + " for kind, pairs in [('missed', missed), ('spurious', spurious), ('correct', correct)]:\n", + " for pair in sorted(pairs):\n", + " boxes = _link_box(pair, by_id)\n", + " ERROR_LINK_ROWS.append({\n", + " 'form_id': form['form_id'],\n", + " 'kind': kind,\n", + " 'link': _link_text(pair, by_id),\n", + " **boxes,\n", + " })\n", + "\n", + "ERROR_ROWS = sorted(ERROR_ROWS, key=lambda r: (r['recall'], r['precision'], r['form_id']))\n", + "WORST_FORM_IDS = [r['form_id'] for r in ERROR_ROWS[:10]]\n", + "print('worst form ids:', WORST_FORM_IDS)\n", "\n", - "miss_rows = sorted(miss_rows, key=lambda r: (r['recall'], r['precision'], r['form_id']))[:10]\n", "try:\n", " import pandas as pd\n", - " display(pd.DataFrame(miss_rows))\n", + " display(pd.DataFrame(ERROR_ROWS[:10]))\n", + " display(pd.DataFrame(ERROR_LINK_ROWS).query(\"form_id in @WORST_FORM_IDS and 
kind != 'correct'\"))\n", "except Exception:\n", - " for row in miss_rows:\n", - " print(row)" + " for row in ERROR_ROWS[:10]:\n", + " print(row)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 4b - Visualize one error case\n", + "\n", + "Set `FORM_ID` to one held-out form id. The overlay is for human debugging only: missed gold links are red, spurious predicted links are orange, and correct links are green. Question boxes are blue and answer boxes are purple." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Visualize one held-out error case with the raw FUNSD image and relation overlay.\n", + "from pathlib import Path\n", + "from PIL import Image, ImageDraw, ImageFont\n", + "from IPython.display import display\n", + "\n", + "FORM_ID = WORST_FORM_IDS[0] if WORST_FORM_IDS else None # or set manually, e.g. '82092117'\n", + "\n", + "def _center(box):\n", + " x0, y0, x1, y1 = box\n", + " return ((x0 + x1) / 2, (y0 + y1) / 2)\n", + "\n", + "def _draw_entity(draw, entity, outline):\n", + " box = entity['box']\n", + " draw.rectangle(box, outline=outline, width=2)\n", + " label = f\"{entity['id']} {entity['label']} {_short(entity['text'], 28)}\"\n", + " x0, y0, *_ = box\n", + " draw.text((x0, max(0, y0 - 12)), label, fill=outline)\n", + "\n", + "def render_funsd_overlay(form_id, max_links_per_kind=40):\n", + " if form_id not in ERROR_FORMS:\n", + " raise ValueError(f\"unknown test form id: {form_id!r}\")\n", + " info = ERROR_FORMS[form_id]\n", + " form = info['form']\n", + " by_id = _ent_lookup(form)\n", + " img_path = config.FUNSD_ROOT / 'testing_data' / 'images' / f\"{form_id}.png\"\n", + " if not img_path.exists():\n", + " raise FileNotFoundError(f\"image not found: {img_path}\")\n", + "\n", + " img = Image.open(img_path).convert('RGB')\n", + " draw = ImageDraw.Draw(img)\n", + " colors = {\n", + " 'correct': (32, 160, 80),\n", + " 'missed': (220, 40, 40),\n", + " 
'spurious': (245, 140, 30),\n", + " 'question': (40, 110, 220),\n", + " 'answer': (145, 70, 210),\n", + " }\n", + "\n", + " involved = set()\n", + " for kind in ['correct', 'missed', 'spurious']:\n", + " pairs = list(sorted(info[kind]))[:max_links_per_kind]\n", + " for qid, aid in pairs:\n", + " q = by_id[qid]\n", + " a = by_id[aid]\n", + " involved.update([qid, aid])\n", + " draw.line([_center(q['box']), _center(a['box'])], fill=colors[kind], width=3)\n", + "\n", + " for eid in involved:\n", + " entity = by_id[eid]\n", + " outline = colors['question'] if entity['label'] == 'question' else colors['answer']\n", + " _draw_entity(draw, entity, outline)\n", + "\n", + " print(f\"{form_id}: precision={info['precision']:.3f} recall={info['recall']:.3f} \"\n", + " f\"correct={len(info['correct'])} missed={len(info['missed'])} spurious={len(info['spurious'])}\")\n", + " print('image:', img_path)\n", + " return img\n", + "\n", + "display(render_funsd_overlay(FORM_ID))\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 4c - Batch preview worst N\n", + "\n", + "Shows a small set of worst held-out overlays inline. Keep `N_PREVIEW` small so the notebook stays responsive." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "N_PREVIEW = 3\n", + "for form_id in WORST_FORM_IDS[:N_PREVIEW]:\n", + " display(render_funsd_overlay(form_id))\n" ] } ], diff --git a/src/eval_funsd.py b/src/eval_funsd.py index 56f9768..7c91b19 100644 --- a/src/eval_funsd.py +++ b/src/eval_funsd.py @@ -8,7 +8,7 @@ - evaluate_pairs: micro P/R/F1 over many prebuilt (pred, gold) pairs (trivially unit-tested). - evaluate_forms: runs the predictor over forms for a scope, then delegates to evaluate_pairs. -Scopes (see plans/phase3-funsd-relations.md): +Scopes (see docs/phase3_brief.md): - "qa": predicted directed (question_id, answer_id) vs qa_gold_links (primary). 
- "all": the same QA predictions cast to undirected frozensets vs all_gold_links - a coverage diagnostic ("how many of ALL GT links does the QA-only heuristic recover"), not a second diff --git a/src/funsd_extraction.py b/src/funsd_extraction.py index e78564b..4478bf3 100644 --- a/src/funsd_extraction.py +++ b/src/funsd_extraction.py @@ -13,7 +13,7 @@ GT links are stored undirected (FUNSD records a link on both endpoints); `qa_gold_links` canonicalizes question+answer pairs to a directed (question_id, answer_id), `all_gold_links` -keeps the full undirected set. See plans/phase3-funsd-relations.md. +keeps the full undirected set. See docs/phase3_brief.md. """ from __future__ import annotations