Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DEVLOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ Reading it honestly:
QA-only heuristic cannot cover header->question and other link types. `all_199` carries the
"contains the 50 test + 149 tuned forms, not held-out" caveat in the report JSON.

Design (locked in discussion; see `plans/phase3-funsd-relations.md`):
Design (locked in discussion; see `docs/phase3_brief.md`):
- **Predictor:** per-answer argmax + distance gate; distances normalized by the form's median
entity height; two separate knobs (`max_distance_units` distance gate, `min_score` floor).
- **GT links:** deduped to undirected frozensets (FUNSD records links bidirectionally), then
Expand Down
141 changes: 141 additions & 0 deletions docs/phase3_brief.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# Phase 3 — FUNSD relation branch (V1 baseline)

> Implementation brief for Phase 3. Committed in the repo (travels with `git pull` to Colab)
> so the references to it in `DEVLOG.md` and the `src/funsd_extraction.py` /
> `src/eval_funsd.py` module docstrings resolve. Status: V1 implemented and scored on
> `feature/phase3-funsd-relations`; headline held-out `test_50.qa_links` F1 0.727.

## Context

Phases 0-2 are merged to `main` (FinTabNet.c table topology + OCR content + table-only RAG +
DocLayNet layout crop). Phase 3 is a **FUNSD relation-linking baseline** over GT entities.

It is a deliberately standalone branch — it does **not** touch the RAG pipeline and only
wires into the demo in Phase 4. It is **annotation-only and CPU-only**: the FUNSD
annotation JSON already carries each entity's text, bbox, label, and GT `linking` pairs,
so scoring, the spatial heuristic, and the gold links never load image pixels. No GPU, no
Colab — this is the local "logic loop", fast `pytest`.

The task: given GT entities `{question, answer, header, other}`, predict which entities are
linked, scored P/R/F1 against GT `linking`.

## Locked design decisions

- **Predictor:** deterministic spatial heuristic, **per-answer argmax + distance gate**.
Each `answer` picks its single highest-scoring `question`; emit the link only if the best
score clears the gate. Distances normalized by the form's **median entity height** so one
gate works across differently-scaled scans. (Rejected: per-question argmax -> under-predicts
multi-answer questions; global threshold -> precision-fragile, needs a tuned cutoff.)
- **GT links:** dedupe `linking` to **undirected** `frozenset({id_a, id_b})` per form (FUNSD
records links bidirectionally/duplicated). Then derive scopes:
- `qa_links` (**primary**): pairs whose endpoints are one `question` + one `answer`,
canonicalized to directed `(question_id, answer_id)`.
- `all_links` (**secondary**): the full deduped undirected set, scored as frozensets.
- **Eval split / reporting matrix:**
- **Primary headline:** `test_50.qa_links.micro_f1` (official 50-form test split).
- Set/tune any heuristic params on **train_149 only**; never on the reported test set.
- **Secondary:** `all_199.qa_links`, `test_50.all_links`, `all_199.all_links`. Print
`all_199` with a "contains the 50 test + 149 tuned forms, not held-out" caveat.
- `debug_20` (first 20 train forms) is for parser/CLI smoke only, never for tuning.
- **all_links is a coverage diagnostic, not a second predictor.** V1 predicts only `q->a`;
the `all_links` row scores those same QA predictions (as undirected frozensets) against the
full GT link set, i.e. "what fraction of all GT links does the QA-only heuristic cover."
- **No sklearn in V1.** P/R/F1 is set arithmetic (`len(pred & gold)/len(pred)` etc.). sklearn
enters only if a fitted ranker is added later, and only fit on train_149.
- **Data:** raw FUNSD zip -> `data/raw/funsd/...` (gitignored). Tests use **synthetic fixtures
only**, never the raw dataset.

## Files

### `src/config.py` — FUNSD paths
```python
FUNSD_ROOT = DATA_ROOT / "raw" / "funsd" / "dataset"
FUNSD_TRAIN = FUNSD_ROOT / "training_data" / "annotations" # 149 forms
FUNSD_TEST = FUNSD_ROOT / "testing_data" / "annotations" # 50 forms
```
Output reuses existing `config.EVALUATION`.

### `src/funsd_extraction.py` — data contract + baseline predictor
Follows the `from __future__ import annotations` + TypedDict + pure-function style of
`src/canonical_schema.py` / `src/eval_retrieval.py`.
- `FunsdEntity` TypedDict: `id, label, text, box [x0,y0,x1,y1]`.
- `FunsdForm` TypedDict: `form_id, entities, gold_links (set[frozenset[int]])`. In-memory it
is a `set` (dedupe is its nature); cast to `list` only when serializing JSON output.
- `parse_funsd_form(data, form_id)` / `parse_funsd_json(path)`: normalize entities, collect
`linking`, dedupe to undirected frozensets, drop self-links and links to missing ids.
- `load_funsd_split(annotations_dir)`.
- `qa_gold_links(form) -> set[tuple[int,int]]`: question+answer pairs -> directed `(q,a)`.
- `all_gold_links(form) -> set[frozenset]`: the deduped undirected set.
- `HeuristicParams` (frozen dataclass, a-priori defaults; the documented tunable surface,
train_149 only). Two clearly-separated knobs, not one fuzzy "gate":
- `max_distance_units`: **distance gate** — a (Q, A) candidate is rejected if its
median-height-normalized distance exceeds this. Filters the candidate set.
- `min_score`: **score threshold** — the per-answer argmax winner is emitted only if its
final score clears this. Acceptance test on the chosen link.
- plus right-band tolerance, below-gap tolerance, and the boost weights.
- `predict_qa_links(form, params=HeuristicParams()) -> set[tuple[int,int]]`: per-answer argmax.
For each answer A, score **every** question candidate (distances scaled by
`median_entity_height(entities)`), drop candidates beyond `max_distance_units`, take the
highest-scoring question, and emit `(q,a)` only if that score >= `min_score`. Geometry:
- **same-row right-side** (A vertically within Q's band, A to Q's right): strongest score.
- **below** (A under Q, horizontally aligned/overlapping): fallback score.
- proximity + alignment are additive boosts. (below and right-side compete in the same
argmax — a below candidate wins whenever it is the best valid candidate.)

### `src/eval_funsd.py` — custom set-based metrics
Split so the metric stays pure (no predictor/params dependency) and the form-runner is separate:
- `prf1(pred, gold) -> dict`: one pred/gold set -> tp/precision/recall/f1, zero-guards.
- `evaluate_pairs(per_form) -> dict`: **pure** micro P/R/F1 over prebuilt (pred, gold) pairs +
counts. No predictor inside, so it is trivially unit-tested with synthetic sets.
- `evaluate_forms(forms, scope, params=HeuristicParams()) -> dict`: builds per-form (pred, gold)
for `scope in {"qa","all"}` — qa = directed `(q,a)` tuples from `predict_qa_links` vs
`qa_gold_links`; all = those predictions cast to undirected frozensets vs `all_gold_links` —
then delegates to `evaluate_pairs`.

### `scripts/evaluate_funsd.py` — CLI runner
Mirrors `scripts/evaluate_rag.py`. Loads `train_149` / `test_50`, builds `all_199` and
`debug_20`, runs split x scope, writes `config.EVALUATION / "phase3_funsd_relations.json"`,
prints the headline `test_50.qa_links` + secondaries. Guards with a `SystemExit` pointing at
`fetch_funsd.py` when the dataset is missing.

### `scripts/fetch_funsd.py` — one-time data helper
`urllib` + `zipfile` download/extract of the official FUNSD zip to `data/raw/funsd/`, with a
`--url` override and a printed manual-download fallback. Not used by tests; not on the gate.

### `tests/test_funsd_relations.py` — synthetic fixtures only (acceptance gate)
Inline tiny forms (dicts), no raw dataset. Covers: parse + gold_links; bidirectional/duplicate
dedupe; qa-link filter and direction canonicalization; all_links scope; `prf1` / `evaluate_pairs`
edge cases; same-row right-side link; below-candidate-wins-when-best; `other` never linked;
per-answer argmax picks the nearer; distance gate; header->question excluded from QA; and two
form-level `evaluate_forms` cases.

### `notebooks/05_phase3_funsd_relations.ipynb` — Colab/local runner (no logic)
Mount/pull/fetch/test/evaluate/display + a read-only qualitative error table.

## Out of scope (V1)
- FUNSD token classification (V2 / seqeval) — future work.
- Image/overlay loading — optional later debug aid, not in the baseline or the gate.
- Any RAG integration — Phase 4.
- sklearn / fitted rankers.

## Result (real FUNSD, untuned a-priori params)

Headline (held-out): `test_50.qa_links` P 0.946 / R 0.590 / **F1 0.727**. `train_149` F1 (0.665)
is below test, so there is no tuning-on-test. Recall is the design ceiling (single-link per
answer + right-side/below geometry); threshold-based multi-link is the documented next lever.
Full split x scope matrix in `DEVLOG.md` (2026-06-03) and
`outputs/evaluation/phase3_funsd_relations.json`.

## Verification
1. **Unit (the gate):** `pytest tests/test_funsd_relations.py` green — fully synthetic, local,
no GPU/network. Full suite 236 passed.
2. **End-to-end (needs the dataset, still local/CPU):** `python scripts/fetch_funsd.py` then
`python scripts/evaluate_funsd.py` -> writes `outputs/evaluation/phase3_funsd_relations.json`.

## Branch / workflow
- This brief is committed at `docs/phase3_brief.md`. The `plans/` directory stays gitignored
for local scratch (harness plan file, PR body draft); the canonical brief lives here.
- Branch `feature/phase3-funsd-relations` cut from `origin/main`. Entirely local phase — no
Colab round-trip needed.
- Build order (TDD): fixtures+tests -> `funsd_extraction.py` -> `eval_funsd.py` ->
`evaluate_funsd.py` -> `fetch_funsd.py` -> docs.
173 changes: 163 additions & 10 deletions notebooks/05_phase3_funsd_relations.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 4 - optional error peek"
"## Step 4a - Error table\n",
"\n",
"Notebook-only qualitative error analysis. This reads FUNSD JSON annotations and relation predictions; it does not write artifacts and does not affect the Phase 3 acceptance gate."
]
},
{
Expand All @@ -199,34 +201,185 @@
"metadata": {},
"outputs": [],
"source": [
"# Show the lowest-recall held-out forms. This is read-only and does not write artifacts.\n",
"# Build a held-out error table with missed/spurious QA links and entity text/bboxes.\n",
"from src.funsd_extraction import load_funsd_split, predict_qa_links, qa_gold_links\n",
"from src import config\n",
"\n",
"def _ent_lookup(form):\n",
" return {e['id']: e for e in form['entities']}\n",
"\n",
"def _short(text, n=70):\n",
" text = ' '.join(str(text).split())\n",
" return text if len(text) <= n else text[:n - 1] + '...'\n",
"\n",
"def _link_text(pair, by_id):\n",
" qid, aid = pair\n",
" q = by_id[qid]\n",
" a = by_id[aid]\n",
" return f\"{qid}:{_short(q['text'])} -> {aid}:{_short(a['text'])}\"\n",
"\n",
"def _link_box(pair, by_id):\n",
" qid, aid = pair\n",
" return {\n",
" 'question_id': qid,\n",
" 'question_box': by_id[qid]['box'],\n",
" 'answer_id': aid,\n",
" 'answer_box': by_id[aid]['box'],\n",
" }\n",
"\n",
"test_forms = load_funsd_split(config.FUNSD_TEST)\n",
"miss_rows = []\n",
"ERROR_FORMS = {}\n",
"ERROR_ROWS = []\n",
"ERROR_LINK_ROWS = []\n",
"\n",
"for form in test_forms:\n",
" by_id = _ent_lookup(form)\n",
" pred = predict_qa_links(form)\n",
" gold = qa_gold_links(form)\n",
" tp = len(pred & gold)\n",
" recall = tp / len(gold) if gold else 0.0\n",
" correct = pred & gold\n",
" missed = gold - pred\n",
" spurious = pred - gold\n",
" tp = len(correct)\n",
" precision = tp / len(pred) if pred else 0.0\n",
" miss_rows.append({\n",
" recall = tp / len(gold) if gold else 0.0\n",
" ERROR_FORMS[form['form_id']] = {\n",
" 'form': form,\n",
" 'pred': pred,\n",
" 'gold': gold,\n",
" 'correct': correct,\n",
" 'missed': missed,\n",
" 'spurious': spurious,\n",
" 'precision': precision,\n",
" 'recall': recall,\n",
" }\n",
" ERROR_ROWS.append({\n",
" 'form_id': form['form_id'],\n",
" 'precision': round(precision, 3),\n",
" 'recall': round(recall, 3),\n",
" 'tp': tp,\n",
" 'pred': len(pred),\n",
" 'gold': len(gold),\n",
" 'missed': len(missed),\n",
" 'spurious': len(spurious),\n",
" 'missed_links': '; '.join(_link_text(pair, by_id) for pair in sorted(missed))[:500],\n",
" 'spurious_links': '; '.join(_link_text(pair, by_id) for pair in sorted(spurious))[:500],\n",
" })\n",
" for kind, pairs in [('missed', missed), ('spurious', spurious), ('correct', correct)]:\n",
" for pair in sorted(pairs):\n",
" boxes = _link_box(pair, by_id)\n",
" ERROR_LINK_ROWS.append({\n",
" 'form_id': form['form_id'],\n",
" 'kind': kind,\n",
" 'link': _link_text(pair, by_id),\n",
" **boxes,\n",
" })\n",
"\n",
"ERROR_ROWS = sorted(ERROR_ROWS, key=lambda r: (r['recall'], r['precision'], r['form_id']))\n",
"WORST_FORM_IDS = [r['form_id'] for r in ERROR_ROWS[:10]]\n",
"print('worst form ids:', WORST_FORM_IDS)\n",
"\n",
"miss_rows = sorted(miss_rows, key=lambda r: (r['recall'], r['precision'], r['form_id']))[:10]\n",
"try:\n",
" import pandas as pd\n",
" display(pd.DataFrame(miss_rows))\n",
" display(pd.DataFrame(ERROR_ROWS[:10]))\n",
" display(pd.DataFrame(ERROR_LINK_ROWS).query(\"form_id in @WORST_FORM_IDS and kind != 'correct'\"))\n",
"except Exception:\n",
" for row in miss_rows:\n",
" print(row)"
" for row in ERROR_ROWS[:10]:\n",
" print(row)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 4b - Visualize one error case\n",
"\n",
"Set `FORM_ID` to one held-out form id. The overlay is for human debugging only: missed gold links are red, spurious predicted links are orange, and correct links are green. Question boxes are blue and answer boxes are purple."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Visualize one held-out error case with the raw FUNSD image and relation overlay.\n",
"from pathlib import Path\n",
"from PIL import Image, ImageDraw, ImageFont\n",
"from IPython.display import display\n",
"\n",
"FORM_ID = WORST_FORM_IDS[0] if WORST_FORM_IDS else None # or set manually, e.g. '82092117'\n",
"\n",
"def _center(box):\n",
" x0, y0, x1, y1 = box\n",
" return ((x0 + x1) / 2, (y0 + y1) / 2)\n",
"\n",
"def _draw_entity(draw, entity, outline):\n",
" box = entity['box']\n",
" draw.rectangle(box, outline=outline, width=2)\n",
" label = f\"{entity['id']} {entity['label']} {_short(entity['text'], 28)}\"\n",
" x0, y0, *_ = box\n",
" draw.text((x0, max(0, y0 - 12)), label, fill=outline)\n",
"\n",
"def render_funsd_overlay(form_id, max_links_per_kind=40):\n",
" if form_id not in ERROR_FORMS:\n",
" raise ValueError(f\"unknown test form id: {form_id!r}\")\n",
" info = ERROR_FORMS[form_id]\n",
" form = info['form']\n",
" by_id = _ent_lookup(form)\n",
" img_path = config.FUNSD_ROOT / 'testing_data' / 'images' / f\"{form_id}.png\"\n",
" if not img_path.exists():\n",
" raise FileNotFoundError(f\"image not found: {img_path}\")\n",
"\n",
" img = Image.open(img_path).convert('RGB')\n",
" draw = ImageDraw.Draw(img)\n",
" colors = {\n",
" 'correct': (32, 160, 80),\n",
" 'missed': (220, 40, 40),\n",
" 'spurious': (245, 140, 30),\n",
" 'question': (40, 110, 220),\n",
" 'answer': (145, 70, 210),\n",
" }\n",
"\n",
" involved = set()\n",
" for kind in ['correct', 'missed', 'spurious']:\n",
" pairs = list(sorted(info[kind]))[:max_links_per_kind]\n",
" for qid, aid in pairs:\n",
" q = by_id[qid]\n",
" a = by_id[aid]\n",
" involved.update([qid, aid])\n",
" draw.line([_center(q['box']), _center(a['box'])], fill=colors[kind], width=3)\n",
"\n",
" for eid in involved:\n",
" entity = by_id[eid]\n",
" outline = colors['question'] if entity['label'] == 'question' else colors['answer']\n",
" _draw_entity(draw, entity, outline)\n",
"\n",
" print(f\"{form_id}: precision={info['precision']:.3f} recall={info['recall']:.3f} \"\n",
" f\"correct={len(info['correct'])} missed={len(info['missed'])} spurious={len(info['spurious'])}\")\n",
" print('image:', img_path)\n",
" return img\n",
"\n",
"display(render_funsd_overlay(FORM_ID))\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 4c - Batch preview worst N\n",
"\n",
"Shows a small set of worst held-out overlays inline. Keep `N_PREVIEW` small so the notebook stays responsive."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"N_PREVIEW = 3\n",
"for form_id in WORST_FORM_IDS[:N_PREVIEW]:\n",
" display(render_funsd_overlay(form_id))\n"
]
}
],
Expand Down
2 changes: 1 addition & 1 deletion src/eval_funsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
- evaluate_pairs: micro P/R/F1 over many prebuilt (pred, gold) pairs (trivially unit-tested).
- evaluate_forms: runs the predictor over forms for a scope, then delegates to evaluate_pairs.

Scopes (see plans/phase3-funsd-relations.md):
Scopes (see docs/phase3_brief.md):
- "qa": predicted directed (question_id, answer_id) vs qa_gold_links (primary).
- "all": the same QA predictions cast to undirected frozensets vs all_gold_links - a coverage
diagnostic ("how many of ALL GT links does the QA-only heuristic recover"), not a second
Expand Down
2 changes: 1 addition & 1 deletion src/funsd_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

GT links are stored undirected (FUNSD records a link on both endpoints); `qa_gold_links`
canonicalizes question+answer pairs to a directed (question_id, answer_id), `all_gold_links`
keeps the full undirected set. See plans/phase3-funsd-relations.md.
keeps the full undirected set. See docs/phase3_brief.md.
"""

from __future__ import annotations
Expand Down
Loading