From 189e97ac6f387aa7b24986864155de9a9461adf6 Mon Sep 17 00:00:00 2001 From: AD2000X Date: Wed, 3 Jun 2026 12:32:40 +0100 Subject: [PATCH] feat: Phase 3 FUNSD relation-linking baseline (V1) Annotation-only, CPU-only relation P/R/F1 over GT entities (no image pixels, no GPU). Held-out test_50.qa_links F1 0.727 (P 0.946 / R 0.590), untuned a-priori params; train_149 F1 0.665 < test, so no tuning-on-test. Bundled: - src/funsd_extraction.py: parse + undirected link dedupe + qa/all scopes + per-answer-argmax + distance-gate predictor (max_distance_units, min_score) - src/eval_funsd.py: set-based prf1 / evaluate_pairs / evaluate_forms - scripts/evaluate_funsd.py: split x scope CLI -> phase3_funsd_relations.json - scripts/fetch_funsd.py: one-time FUNSD download/extract helper - tests/test_funsd_relations.py: 17 synthetic tests (acceptance gate) - notebooks/05_phase3_funsd_relations.ipynb: Colab/local runner (no logic) - src/config.py: FUNSD paths; .gitignore: plans/ - DEVLOG.md + PLAN.md: Phase 3 result + section 7 next-steps refresh --- .gitignore | 1 + DEVLOG.md | 50 +++++ PLAN.md | 33 +-- notebooks/05_phase3_funsd_relations.ipynb | 254 ++++++++++++++++++++++ scripts/evaluate_funsd.py | 93 ++++++++ scripts/fetch_funsd.py | 68 ++++++ src/config.py | 8 + src/eval_funsd.py | 69 ++++++ src/funsd_extraction.py | 210 ++++++++++++++++++ tests/test_funsd_relations.py | 216 ++++++++++++++++++ 10 files changed, 989 insertions(+), 13 deletions(-) create mode 100644 notebooks/05_phase3_funsd_relations.ipynb create mode 100644 scripts/evaluate_funsd.py create mode 100644 scripts/fetch_funsd.py create mode 100644 src/eval_funsd.py create mode 100644 src/funsd_extraction.py create mode 100644 tests/test_funsd_relations.py diff --git a/.gitignore b/.gitignore index 453259d..d4bef22 100644 --- a/.gitignore +++ b/.gitignore @@ -19,6 +19,7 @@ venv/ .mcp.json tmp/ memory/ +plans/ # Scratch screenshots (never committed) output*.png diff --git a/DEVLOG.md b/DEVLOG.md index 3e5e30f..0b47483 
100644 --- a/DEVLOG.md +++ b/DEVLOG.md @@ -181,6 +181,56 @@ Decisions outgrow this file, split them into `DECISIONS.md` (or `docs/adr/`). --- +## 2026-06-03 - Phase 3 FUNSD relation-linking baseline (V1) + +### Result - annotation-only spatial heuristic; high precision, recall is the design ceiling + +First Phase 3 deliverable: a deterministic FUNSD relation-linking baseline over GT entities, +CPU-only and annotation-only (the FUNSD JSON carries entity text/bbox/label and the GT +`linking` pairs, so no image pixels are loaded). Run on the real dataset (149 train + 50 test += 199 forms), `scripts/evaluate_funsd.py`, untuned a-priori params: + +| split | scope | precision | recall | f1 | tp / pred / gold | n | +| --- | --- | --- | --- | --- | --- | --- | +| **test_50** | **qa_links** | **0.946** | **0.590** | **0.727** | 494 / 522 / 837 | 50 | +| all_199 | qa_links | 0.925 | 0.535 | 0.678 | 2123 / 2295 / 3966 | 199 | +| test_50 | all_links | 0.946 | 0.464 | 0.623 | 494 / 522 / 1064 | 50 | +| all_199 | all_links | 0.925 | 0.401 | 0.560 | 2123 / 2295 / 5293 | 199 | +| train_149 | qa_links | 0.919 | 0.521 | 0.665 | 1629 / 1773 / 3129 | 149 | + +Reading it honestly: +- **Headline (held-out): `test_50.qa_links.micro_f1` = 0.727**, precision 0.946. The heuristic + fires conservatively and is right when it does; the limit is recall. +- **Recall (0.590) is the design ceiling, not a bug.** Per-answer argmax emits at most one link + per answer, and the geometry only models same-row right-side and below relations - so answers + whose question sits left/above, or that have multiple gold questions, are under-covered. The + rejected alternatives (per-question argmax, global threshold) trade this for precision; richer + matching (threshold-based multi-link) is the documented next lever, deliberately out of V1. 
+- **No tuning-on-test risk.** Params are a-priori defaults; `train_149` F1 (0.665) is *below* + `test_50` (0.727), so test is if anything the easier split - the gap is sampling, not fitting. +- **`all_links` is a coverage diagnostic, not a second predictor.** Same QA predictions scored + (as undirected pairs) against every GT link; recall necessarily drops (0.464 test) because the + QA-only heuristic cannot cover header->question and other link types. `all_199` carries the + "contains the 50 test + 149 tuned forms, not held-out" caveat in the report JSON. + +Design (locked in discussion; see `plans/phase3-funsd-relations.md`): +- **Predictor:** per-answer argmax + distance gate; distances normalized by the form's median + entity height; two separate knobs (`max_distance_units` distance gate, `min_score` floor). +- **GT links:** deduped to undirected frozensets (FUNSD records links bidirectionally), then + `qa_gold_links` canonicalizes question+answer to directed `(q,a)`; `all_gold_links` keeps the + full undirected set. +- **Reporting matrix:** primary `test_50.qa_links` (held-out); secondary `all_199.qa_links`, + `test_50/all_199.all_links`; `train_149` is the dev/tuning split. +- **No sklearn in V1** (P/R/F1 is set arithmetic). **No image loading** (optional later for + qualitative overlay/debug only, never in the baseline or the gate). +- **Scope held:** standalone branch; does not touch the RAG pipeline. FUNSD token classification + (V2 / seqeval) is future work. +- **Files:** `src/funsd_extraction.py`, `src/eval_funsd.py`, `scripts/evaluate_funsd.py`, + `scripts/fetch_funsd.py`, `tests/test_funsd_relations.py` (17 synthetic tests, the gate), + `src/config.py` (FUNSD paths). Full suite 236 passed. 
+ +--- + ## 2026-06-02 - Phase 2 DocLayNet layout-crop MVP gate ### Finding - Aryn primary carries forward; fallback is narrow; crop->structure needs band dedup diff --git a/PLAN.md b/PLAN.md index 326222b..d417878 100644 --- a/PLAN.md +++ b/PLAN.md @@ -542,19 +542,26 @@ Implementation details: ## 7. Next steps -**Phases 0 through 1C are complete and merged** (v1 = table-only RAG). The active branch is -**Phase 2 (DocLayNet layout integration)**. The detector is pinned, the pure geometry/layout -modules are tested, the fixed DocLayNet MVP subset is scored, and the crop->TATR structure -handoff is validated on sampled crops. - -Remaining Phase 2 close-out: - -1. Pull the branch on Colab and rerun Step 7d with the tightened empty-grid validator - (`scripts/smoke_structure.py --n 286 --seed 42`). Expected: 285 OK / 1 WARN, still under - the <=5% WARN gate. -2. Record the confirmed Step 7d result in `DEVLOG.md`, then open the Phase 2 PR. -3. After merge, repin Colab notebooks to `main`; do not start Phase 3 until the Phase 2 PR is - merged or explicitly paused. +**Phases 0 through 2 are complete and merged** (v1 = table-only RAG; Phase 2 = DocLayNet +layout-crop integration, merged to `main` 2026-06-03). The active branch is +**Phase 3 (FUNSD relation branch)**, `feature/phase3-funsd-relations` off `main`. + +Phase 3 V1 is implemented and scored (entirely local, CPU-only, no Colab): + +1. Annotation-only deterministic relation baseline: `src/funsd_extraction.py` (parse + dedupe + + per-answer-argmax predictor), `src/eval_funsd.py` (set-based P/R/F1), + `scripts/evaluate_funsd.py` (+ `scripts/fetch_funsd.py`), `tests/test_funsd_relations.py` + (17 synthetic tests). Full suite 236 passed. +2. Headline (held-out `test_50.qa_links`): P 0.946 / R 0.590 / **F1 0.727**; secondaries in + `DEVLOG.md` (2026-06-03) and `outputs/evaluation/phase3_funsd_relations.json`. + +Remaining: + +1. Open the Phase 3 PR. +2. 
Optional (train-only): tune `HeuristicParams` on `train_149` if higher recall is wanted; + never on `test_50`. FUNSD token classification (V2 / seqeval) and threshold-based multi-link + matching are future work, not V1. +3. Phase 4 (full demo + evaluation + report) is the next phase. --- diff --git a/notebooks/05_phase3_funsd_relations.ipynb b/notebooks/05_phase3_funsd_relations.ipynb new file mode 100644 index 0000000..4a10322 --- /dev/null +++ b/notebooks/05_phase3_funsd_relations.ipynb @@ -0,0 +1,254 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Phase 3 - FUNSD relation baseline (Colab runner)\n", + "\n", + "Runner only: mount Drive, pull the Phase 3 branch, fetch the raw FUNSD annotations, run the synthetic unit gate, then run the real FUNSD relation evaluation.\n", + "\n", + "Phase 3 is annotation-only and CPU-only. The FUNSD JSON carries entity text, bbox, label, and GT linking pairs, so this notebook does not load image pixels and does not need a GPU. Logic lives in `src/` and `scripts/`, not in this notebook.\n", + "\n", + "Before running in Colab, make sure `feature/phase3-funsd-relations` has been pushed to GitHub. After Phase 3 merges, set `BRANCH = 'main'` in the boot cell." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Boot" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 1. Mount Drive so config.DATA_ROOT and config.OUTPUT_ROOT persist across Colab sessions.\n", + "from google.colab import drive\n", + "drive.mount('/content/drive')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 2. 
Get the code onto the VM and pin the Phase 3 branch.\n", + "import os\n", + "\n", + "REPO = '/content/FinDocStructRAG'\n", + "BRANCH = 'feature/phase3-funsd-relations' # change to 'main' after Phase 3 merges\n", + "\n", + "if not os.path.isdir(f'{REPO}/.git'):\n", + " !git clone --quiet https://github.com/AD2000X/FinDocStructRAG.git {REPO}\n", + "\n", + "!cd {REPO} && git fetch origin --quiet\n", + "!cd {REPO} && git checkout {BRANCH} && git pull --ff-only origin {BRANCH}\n", + "!cd {REPO} && echo branch: $(git rev-parse --abbrev-ref HEAD) HEAD: $(git log --oneline -1)\n", + "%cd /content/FinDocStructRAG" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 3. Make src/ importable and sanity-check the Phase 3 paths.\n", + "import importlib\n", + "import sys\n", + "\n", + "sys.path.insert(0, '/content/FinDocStructRAG')\n", + "from src import config\n", + "importlib.reload(config)\n", + "\n", + "print('IN_COLAB :', config.IN_COLAB)\n", + "print('DATA_ROOT :', config.DATA_ROOT)\n", + "print('OUTPUT_ROOT :', config.OUTPUT_ROOT)\n", + "print('FUNSD_ROOT :', config.FUNSD_ROOT)\n", + "print('FUNSD_TRAIN :', config.FUNSD_TRAIN)\n", + "print('FUNSD_TEST :', config.FUNSD_TEST)\n", + "print('EVALUATION :', config.EVALUATION)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 1 - fetch or reuse FUNSD annotations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Downloads the official FUNSD zip only if it is not already present on Drive.\n", + "# It extracts to data/raw/funsd/dataset/...; tests never depend on this data.\n", + "!python scripts/fetch_funsd.py" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Dataset count sanity check. 
Expected: 149 train + 50 test annotations.\n", + "import importlib\n", + "from src import config\n", + "importlib.reload(config)\n", + "\n", + "n_train = len(list(config.FUNSD_TRAIN.glob('*.json')))\n", + "n_test = len(list(config.FUNSD_TEST.glob('*.json')))\n", + "print('train annotations:', n_train)\n", + "print('test annotations :', n_test)\n", + "assert n_train == 149 and n_test == 50, 'Unexpected FUNSD annotation counts'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 2 - unit gate" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Small dependency for the synthetic acceptance tests. The Phase 3 runtime itself is stdlib-only.\n", + "!python -m pip install -q pytest\n", + "!python -m pytest tests/test_funsd_relations.py" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 3 - run Phase 3 evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!python scripts/evaluate_funsd.py" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load the JSON report and display the split x scope matrix.\n", + "import json\n", + "from pathlib import Path\n", + "\n", + "from src import config\n", + "\n", + "report_path = config.EVALUATION / 'phase3_funsd_relations.json'\n", + "report = json.loads(report_path.read_text(encoding='utf-8'))\n", + "print('report:', report_path)\n", + "print('primary:', report['primary'])\n", + "print('note :', report['note'])\n", + "\n", + "rows = []\n", + "for split, scopes in report['results'].items():\n", + " for scope, m in scopes.items():\n", + " rows.append({\n", + " 'split': split,\n", + " 'scope': scope,\n", + " 'precision': round(m['precision'], 3),\n", + " 'recall': round(m['recall'], 3),\n", + " 'f1': round(m['f1'], 3),\n", + " 'tp': m['tp'],\n", + " 'pred': m['n_pred'],\n", + 
" 'gold': m['n_gold'],\n", + " 'forms': m['num_forms'],\n", + " })\n", + "\n", + "try:\n", + " import pandas as pd\n", + " display(pd.DataFrame(rows))\n", + "except Exception:\n", + " for row in rows:\n", + " print(row)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 4 - optional error peek" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Show the lowest-recall held-out forms. This is read-only and does not write artifacts.\n", + "from src.funsd_extraction import load_funsd_split, predict_qa_links, qa_gold_links\n", + "from src import config\n", + "\n", + "test_forms = load_funsd_split(config.FUNSD_TEST)\n", + "miss_rows = []\n", + "for form in test_forms:\n", + " pred = predict_qa_links(form)\n", + " gold = qa_gold_links(form)\n", + " tp = len(pred & gold)\n", + " recall = tp / len(gold) if gold else 0.0\n", + " precision = tp / len(pred) if pred else 0.0\n", + " miss_rows.append({\n", + " 'form_id': form['form_id'],\n", + " 'precision': round(precision, 3),\n", + " 'recall': round(recall, 3),\n", + " 'tp': tp,\n", + " 'pred': len(pred),\n", + " 'gold': len(gold),\n", + " })\n", + "\n", + "miss_rows = sorted(miss_rows, key=lambda r: (r['recall'], r['precision'], r['form_id']))[:10]\n", + "try:\n", + " import pandas as pd\n", + " display(pd.DataFrame(miss_rows))\n", + "except Exception:\n", + " for row in miss_rows:\n", + " print(row)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/scripts/evaluate_funsd.py b/scripts/evaluate_funsd.py new file mode 100644 index 0000000..fa6b8fb 
--- /dev/null +++ b/scripts/evaluate_funsd.py @@ -0,0 +1,93 @@ +"""Phase 3: FUNSD relation-linking evaluation (CPU, no GPU, no network). + +Annotation-only baseline. Reads the local FUNSD annotation JSON, runs the deterministic +per-answer-argmax heuristic, and reports relation P/R/F1 across split x scope: + + python scripts/evaluate_funsd.py + +Splits: train_149 (tuning/dev), test_50 (PRIMARY headline), all_199 = train + test + (secondary, NOT held-out - it contains the 50 test + 149 tuned forms), debug_20. +Scopes: qa_links (primary, question->answer) and all_links (secondary coverage diagnostic). + +Headline metric: test_50.qa_links.micro_f1. Writes +outputs/evaluation/phase3_funsd_relations.json. +""" + +from __future__ import annotations + +import argparse +import json +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from src import config # noqa: E402 +from src.eval_funsd import evaluate_forms # noqa: E402 +from src.funsd_extraction import HeuristicParams, load_funsd_split # noqa: E402 + +SCOPES = ("qa", "all") + + +def main() -> None: + ap = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + ap.add_argument("--debug-n", type=int, default=20, + help="size of the debug split (first N train forms; parser/CLI smoke only)") + args = ap.parse_args() + + if not config.FUNSD_TRAIN.is_dir() or not config.FUNSD_TEST.is_dir(): + raise SystemExit( + f"FUNSD annotations not found under {config.FUNSD_ROOT}\n" + f" expected: {config.FUNSD_TRAIN} and {config.FUNSD_TEST}\n" + f" run: python scripts/fetch_funsd.py") + + train = load_funsd_split(config.FUNSD_TRAIN) + test = load_funsd_split(config.FUNSD_TEST) + debug_n = min(args.debug_n, len(train)) + splits = { + "train_149": train, + "test_50": test, + "all_199": train + test, + f"debug_{debug_n}": train[:debug_n], + } + + params = HeuristicParams() + report: dict = { + "params": params.__dict__, + "split_sizes": 
{name: len(forms) for name, forms in splits.items()}, + "primary": "test_50.qa_links", + "note": ("all_199 contains the 50 test + 149 tuned forms and is NOT held-out; " + "the held-out headline is test_50.qa_links. all_links is a coverage " + "diagnostic of the QA-only predictor, not a second predictor."), + "results": {}, + } + for name, forms in splits.items(): + report["results"][name] = { + f"{scope}_links": evaluate_forms(forms, scope, params) for scope in SCOPES + } + + out_dir = config.EVALUATION + out_dir.mkdir(parents=True, exist_ok=True) + out_path = out_dir / "phase3_funsd_relations.json" + out_path.write_text(json.dumps(report, indent=2), encoding="utf-8") + + # Console summary: headline first, then secondaries. + def row(split: str, scope: str) -> str: + m = report["results"][split][f"{scope}_links"] + return (f"{split:<10} {scope+'_links':<10} " + f"P {m['precision']:.3f} R {m['recall']:.3f} F1 {m['f1']:.3f} " + f"(tp {m['tp']} / pred {m['n_pred']} / gold {m['n_gold']}, n={m['num_forms']})") + + print("HEADLINE (held-out):") + print(" " + row("test_50", "qa")) + print("\nSecondary:") + print(" " + row("all_199", "qa")) + print(" " + row("test_50", "all")) + print(" " + row("all_199", "all")) + print(" " + row("train_149", "qa") + " [dev/tuning split]") + print(f"\nreport -> {out_path}") + + +if __name__ == "__main__": + main() diff --git a/scripts/fetch_funsd.py b/scripts/fetch_funsd.py new file mode 100644 index 0000000..8325384 --- /dev/null +++ b/scripts/fetch_funsd.py @@ -0,0 +1,68 @@ +"""One-time FUNSD dataset download/extract helper (Phase 3). + +The FUNSD zip extracts a `dataset/` tree with training_data/ and testing_data/, each holding +annotations/ (the JSON V1 needs) and images/. It lands under data/raw/funsd/ (gitignored), so +config.FUNSD_ROOT resolves to data/raw/funsd/dataset. 
+ + python scripts/fetch_funsd.py # download + extract + python scripts/fetch_funsd.py --url URL # if the default URL is unreachable + +If the download fails (the host is occasionally down), grab dataset.zip manually and unzip it +into data/raw/funsd/ so that data/raw/funsd/dataset/training_data/annotations/ exists. Tests +never touch this; it only feeds scripts/evaluate_funsd.py. +""" + +from __future__ import annotations + +import argparse +import sys +import tempfile +import urllib.request +import zipfile +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from src import config # noqa: E402 + +DEFAULT_URL = "https://guillaumejaume.github.io/FUNSD/dataset.zip" + + +def main() -> None: + ap = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + ap.add_argument("--url", default=DEFAULT_URL, help="FUNSD dataset.zip URL") + ap.add_argument("--force", action="store_true", help="re-download even if present") + args = ap.parse_args() + + dest = config.FUNSD_ROOT.parent # data/raw/funsd (the zip carries dataset/) + if config.FUNSD_TRAIN.is_dir() and config.FUNSD_TEST.is_dir() and not args.force: + print(f"FUNSD already present at {config.FUNSD_ROOT} (use --force to re-download)") + return + + dest.mkdir(parents=True, exist_ok=True) + with tempfile.TemporaryDirectory() as tmp: + zip_path = Path(tmp) / "funsd.zip" + print(f"downloading {args.url} ...") + try: + urllib.request.urlretrieve(args.url, zip_path) + except Exception as exc: # network/host failure -> point at the manual path + raise SystemExit( + f"download failed: {exc}\n" + f" fetch dataset.zip manually and unzip into {dest} so that\n" + f" {config.FUNSD_TRAIN} exists, then re-run scripts/evaluate_funsd.py") + print(f"extracting -> {dest}") + with zipfile.ZipFile(zip_path) as zf: + zf.extractall(dest) + + if not config.FUNSD_TRAIN.is_dir(): + raise SystemExit( + f"extracted, but {config.FUNSD_TRAIN} is missing - the 
archive layout may differ; " + f"check {dest} and move annotations so config.FUNSD_TRAIN/TEST resolve.") + n_train = len(list(config.FUNSD_TRAIN.glob("*.json"))) + n_test = len(list(config.FUNSD_TEST.glob("*.json"))) + print(f"ready: {n_train} train + {n_test} test annotation files under {config.FUNSD_ROOT}") + + +if __name__ == "__main__": + main() diff --git a/src/config.py b/src/config.py index 9b97cc1..8d2bc96 100644 --- a/src/config.py +++ b/src/config.py @@ -56,6 +56,7 @@ def in_colab() -> bool: CHUNKS = RAG_INDEX / "chunks" # serialized table chunks per (text_source, serialization) QA_DIR = OUTPUT_ROOT / "qa" # generated + merged QA sets (on Drive) EVALUATION = OUTPUT_ROOT / "evaluation" +FUNSD_OUTPUT = OUTPUT_ROOT / "funsd" # Phase 3: relation-linking artifacts (kept separate, P4) FAILURE_LOGS = OUTPUT_ROOT / "failure_logs" MANIFESTS = OUTPUT_ROOT / "manifests" FIGURES = OUTPUT_ROOT / "figures" @@ -66,6 +67,13 @@ def in_colab() -> bool: # is version-controlled and travels with git pull. QA_MANUAL_SEED = ROOT / "qa" / "qa_manual_seed.jsonl" +# Phase 3 FUNSD dataset (raw annotation JSON; gitignored under data/, fetched by +# scripts/fetch_funsd.py). V1 is annotation-only: the JSON carries entity text/bbox/label +# and the GT linking pairs, so no image pixels are loaded. 149 train + 50 test = 199 forms. +FUNSD_ROOT = DATA_ROOT / "raw" / "funsd" / "dataset" +FUNSD_TRAIN = FUNSD_ROOT / "training_data" / "annotations" # 149 forms +FUNSD_TEST = FUNSD_ROOT / "testing_data" / "annotations" # 50 forms + # Model IDs (DESIGN_SPEC sections 4.2 and 7; PLAN section 0). TATR_STRUCTURE_MODEL = "microsoft/table-transformer-structure-recognition-v1.1-fin" TATR_DETECTION_MODEL = "microsoft/table-transformer-detection" diff --git a/src/eval_funsd.py b/src/eval_funsd.py new file mode 100644 index 0000000..56f9768 --- /dev/null +++ b/src/eval_funsd.py @@ -0,0 +1,69 @@ +"""FUNSD relation metrics (Phase 3, V1). 
+ +Custom set-based precision/recall/F1 over predicted-link vs GT-link sets - no sklearn (P/R/F1 +is pure set arithmetic). Split so the metric stays pure (no predictor inside) and the +form-runner is separate: + +- prf1: one (pred, gold) set pair -> P/R/F1. +- evaluate_pairs: micro P/R/F1 over many prebuilt (pred, gold) pairs (trivially unit-tested). +- evaluate_forms: runs the predictor over forms for a scope, then delegates to evaluate_pairs. + +Scopes (see plans/phase3-funsd-relations.md): +- "qa": predicted directed (question_id, answer_id) vs qa_gold_links (primary). +- "all": the same QA predictions cast to undirected frozensets vs all_gold_links - a coverage + diagnostic ("how many of ALL GT links does the QA-only heuristic recover"), not a second + predictor. +""" + +from __future__ import annotations + +from src.funsd_extraction import ( + HeuristicParams, + FunsdForm, + all_gold_links, + predict_qa_links, + qa_gold_links, +) + + +def _prf(tp: int, n_pred: int, n_gold: int) -> dict: + precision = tp / n_pred if n_pred else 0.0 + recall = tp / n_gold if n_gold else 0.0 + f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0 + return {"tp": tp, "n_pred": n_pred, "n_gold": n_gold, + "precision": precision, "recall": recall, "f1": f1} + + +def prf1(pred: set, gold: set) -> dict: + """P/R/F1 for a single (pred, gold) set pair. Empty pred or gold scores 0 on that side.""" + return _prf(len(pred & gold), len(pred), len(gold)) + + +def evaluate_pairs(per_form: list[tuple[set, set]]) -> dict: + """Micro P/R/F1 over prebuilt (pred, gold) pairs: sum tp/|pred|/|gold| across forms, then + compute once. 
Pure - no predictor or params.""" + tp = sum(len(p & g) for p, g in per_form) + n_pred = sum(len(p) for p, _ in per_form) + n_gold = sum(len(g) for _, g in per_form) + out = _prf(tp, n_pred, n_gold) + out["num_forms"] = len(per_form) + return out + + +def evaluate_forms(forms: list[FunsdForm], scope: str, + params: HeuristicParams = HeuristicParams()) -> dict: + """Run the heuristic over forms and score it for a scope ("qa" or "all").""" + if scope not in ("qa", "all"): + raise ValueError(f"unknown scope: {scope!r} (use 'qa' or 'all')") + + per_form: list[tuple[set, set]] = [] + for form in forms: + pred = predict_qa_links(form, params) + if scope == "qa": + per_form.append((pred, qa_gold_links(form))) + else: # "all": score QA predictions as undirected pairs against every GT link + per_form.append(({frozenset(pair) for pair in pred}, all_gold_links(form))) + + out = evaluate_pairs(per_form) + out["scope"] = scope + return out diff --git a/src/funsd_extraction.py b/src/funsd_extraction.py new file mode 100644 index 0000000..e78564b --- /dev/null +++ b/src/funsd_extraction.py @@ -0,0 +1,210 @@ +"""FUNSD relation-linking baseline (Phase 3, V1). + +Annotation-only and CPU-only: the FUNSD annotation JSON already carries each entity's text, +bbox, label ({question, answer, header, other}), and the GT `linking` pairs, so nothing here +loads image pixels. The module provides the data contract (parse + normalize + dedupe GT +links), the two link scopes used for reporting, and a deterministic spatial predictor. + +Predictor: per-answer argmax + distance gate. Each `answer` scores every `question` +candidate (a same-row right-side relation or a below relation), the candidate set is filtered +by a distance gate, and the answer is linked to its single best question if that score clears +a floor. Distances are normalized by the form's median entity height so one set of thresholds +works across differently-scaled scans. 
+ +GT links are stored undirected (FUNSD records a link on both endpoints); `qa_gold_links` +canonicalizes question+answer pairs to a directed (question_id, answer_id), `all_gold_links` +keeps the full undirected set. See plans/phase3-funsd-relations.md. +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from pathlib import Path +from statistics import median +from typing import TypedDict + +QUESTION = "question" +ANSWER = "answer" + + +class FunsdEntity(TypedDict): + """A FUNSD form entity. box is [x0, y0, x1, y1] in image pixels (top-left origin).""" + + id: int + label: str + text: str + box: list[float] + + +class FunsdForm(TypedDict): + """A parsed form. gold_links is the deduped undirected GT link set (cast to list only + when serializing to JSON).""" + + form_id: str + entities: list[FunsdEntity] + gold_links: set[frozenset[int]] + + +@dataclass(frozen=True) +class HeuristicParams: + """Tunable surface for the relation heuristic. Defaults are a-priori; any fitting is done + on the FUNSD train split only (never on the reported test set). Distances are in units of + the form's median entity height. + + Two clearly-separated knobs, not one fuzzy gate: + - max_distance_units: the distance gate that filters candidates too far to be plausible. + - min_score: the floor the per-answer argmax winner must clear to be emitted. 
+ """ + + right_base: float = 1.0 # base score for a same-row right-side answer + below_base: float = 0.7 # base score for a below answer + right_band_tol: float = 0.7 # vertical-center tolerance for "same row" (H units) + below_align_tol: float = 1.0 # left-edge tolerance for "below" alignment (H units) + align_boost: float = 0.5 # reward for tighter band / left-edge alignment + dist_penalty: float = 0.3 # score lost per median-height of gap + max_distance_units: float = 8.0 # distance gate: reject candidates beyond this (H units) + min_score: float = 0.0 # score floor on the chosen link + + +# --- parsing --- + + +def parse_funsd_form(data: dict, form_id: str) -> FunsdForm: + """Build a FunsdForm from the FUNSD JSON shape ({"form": [ {id, label, box, ...}, ... ]}). + + Links are collected across all entities and deduped to undirected frozensets; self-links + and links referencing a missing id are dropped. + """ + raw = data.get("form", []) + entities: list[FunsdEntity] = [] + ids: set[int] = set() + for e in raw: + eid = int(e["id"]) + entities.append(FunsdEntity( + id=eid, + label=str(e.get("label", "")), + text=str(e.get("text", "")), + box=[float(v) for v in e["box"]], + )) + ids.add(eid) + + gold: set[frozenset[int]] = set() + for e in raw: + for pair in e.get("linking", []): + a, b = int(pair[0]), int(pair[1]) + if a == b or a not in ids or b not in ids: + continue + gold.add(frozenset((a, b))) + + return FunsdForm(form_id=form_id, entities=entities, gold_links=gold) + + +def parse_funsd_json(path) -> FunsdForm: + path = Path(path) + data = json.loads(path.read_text(encoding="utf-8")) + return parse_funsd_form(data, path.stem) + + +def load_funsd_split(annotations_dir) -> list[FunsdForm]: + """Parse every *.json in a FUNSD annotations directory (sorted for determinism).""" + return [parse_funsd_json(p) for p in sorted(Path(annotations_dir).glob("*.json"))] + + +# --- GT link scopes --- + + +def _labels(form: FunsdForm) -> dict[int, str]: + return {e["id"]: 
def qa_gold_links(form: FunsdForm) -> set[tuple[int, int]]:
    """Return the question-answer gold links as directed pairs (primary scope).

    Each undirected gold link whose endpoints are labelled one QUESTION and one
    ANSWER is emitted as ``(question_id, answer_id)``, whichever order the pair
    was stored in. Links between any other label combination are left out.
    """
    kind_of = _labels(form)
    directed: set[tuple[int, int]] = set()
    for link in form["gold_links"]:
        first, second = tuple(link)
        kind_first, kind_second = kind_of[first], kind_of[second]
        if kind_first == QUESTION and kind_second == ANSWER:
            directed.add((first, second))
        elif kind_first == ANSWER and kind_second == QUESTION:
            directed.add((second, first))
    return directed


def all_gold_links(form: FunsdForm) -> set[frozenset[int]]:
    """Return a fresh set holding every gold link of the form (secondary coverage scope)."""
    everything: set[frozenset[int]] = set()
    everything.update(form["gold_links"])
    return everything


# --- heuristic predictor ---


def median_entity_height(entities: list[FunsdEntity]) -> float:
    """Median of ``box[3] - box[1]`` over the entities; 1.0 when there are none."""
    spans = [ent["box"][3] - ent["box"][1] for ent in entities]
    if not spans:
        return 1.0
    return float(median(spans))


def _overlap_1d(a0: float, a1: float, b0: float, b1: float) -> float:
    """Length shared by the intervals [a0, a1] and [b0, b1], floored at 0.0."""
    shared = min(a1, b1) - max(a0, b0)
    return max(0.0, shared)


def _score(q: FunsdEntity, a: FunsdEntity, h: float, params: HeuristicParams) -> float | None:
    """Score the link from question ``q`` to answer ``a``; None when ``a`` is no candidate.

    Two geometric relations are tried, both measured in units of ``h`` (clamped to
    at least 1e-6): "same row, answer to the right" and "answer below". Each one
    that applies yields a score and a gap; relations whose gap exceeds
    ``params.max_distance_units`` are discarded, and the highest surviving score
    is returned. The two relations compete directly - neither is a fallback.
    """
    q_left, q_top, q_right, q_bottom = q["box"]
    a_left, a_top, a_right, a_bottom = a["box"]
    q_mid_x, q_mid_y = (q_left + q_right) / 2, (q_top + q_bottom) / 2
    a_mid_x, a_mid_y = (a_left + a_right) / 2, (a_top + a_bottom) / 2
    unit = h if h > 1e-6 else 1e-6

    gated: list[float] = []  # scores of the relations that passed the distance gate

    # Same-row right-side: answer centre right of the question centre, and the
    # vertical centre offset within the band tolerance.
    row_offset = abs(a_mid_y - q_mid_y) / unit
    if a_mid_x > q_mid_x and row_offset <= params.right_band_tol:
        side_gap = max(0.0, a_left - q_right) / unit
        side_score = (params.right_base
                      + params.align_boost * (params.right_band_tol - row_offset)
                      - params.dist_penalty * side_gap)
        if side_gap <= params.max_distance_units:
            gated.append(side_score)

    # Below: answer centre under the question centre, and either the boxes overlap
    # horizontally or their left edges are close enough.
    left_shift = abs(a_left - q_left) / unit
    if a_mid_y > q_mid_y and (_overlap_1d(q_left, q_right, a_left, a_right) > 0
                              or left_shift <= params.below_align_tol):
        drop_gap = max(0.0, a_top - q_bottom) / unit
        drop_score = (params.below_base
                      + params.align_boost * max(0.0, params.below_align_tol - left_shift)
                      - params.dist_penalty * drop_gap)
        if drop_gap <= params.max_distance_units:
            gated.append(drop_score)

    if not gated:
        return None
    return max(gated)


def predict_qa_links(form: FunsdForm,
                     params: HeuristicParams = HeuristicParams()) -> set[tuple[int, int]]:
    """Predict directed ``(question_id, answer_id)`` links by per-answer argmax.

    Every ANSWER entity is scored against every QUESTION entity with ``_score``,
    using the form's median entity height as the distance unit. The answer is
    linked to its single best question, provided that best score reaches
    ``params.min_score``; an answer with no scorable question yields no link.
    """
    entities = form["entities"]
    questions = [ent for ent in entities if ent["label"] == QUESTION]
    answers = [ent for ent in entities if ent["label"] == ANSWER]
    h = median_entity_height(entities)

    predicted: set[tuple[int, int]] = set()
    for answer in answers:
        leader: tuple[float, FunsdEntity] | None = None  # (score, question) so far
        for question in questions:
            score = _score(question, answer, h, params)
            if score is None:
                continue
            # Deterministic: a strictly higher score wins; equal scores go to the
            # lower question id.
            takes_lead = (
                leader is None
                or score > leader[0]
                or (score == leader[0] and question["id"] < leader[1]["id"])
            )
            if takes_lead:
                leader = (score, question)
        if leader is None:
            continue
        top_score, top_question = leader
        if top_score >= params.min_score:
            predicted.add((top_question["id"], answer["id"]))
    return predicted
"""Synthetic, CPU-only tests for the Phase 3 FUNSD relation-linking baseline.

Forms are assembled from inline dicts in the shape FUNSD's annotation JSON uses
({"form": [ {id, label, box, text, linking}, ... ]}); neither the raw dataset nor any
image pixels are read. The suite exercises the parser, link dedupe/scoping, the
per-answer argmax + distance-gate heuristic, and the set-based P/R/F1 metric.
"""

import json

from dataclasses import replace

from src.funsd_extraction import (
    HeuristicParams,
    all_gold_links,
    median_entity_height,
    parse_funsd_form,
    parse_funsd_json,
    predict_qa_links,
    qa_gold_links,
)
from src.eval_funsd import evaluate_forms, evaluate_pairs, prf1


def _ent(eid, label, box, *, text="", linking=None):
    """Build one raw FUNSD-style entity dict; ``box`` and ``linking`` are copied to lists."""
    links = list(linking) if linking else []
    return {"id": eid, "label": label, "box": list(box), "text": text,
            "linking": links}


def _form(entities, form_id="f0"):
    """Wrap raw entity dicts in a ``{"form": ...}`` payload and parse it."""
    raw = {"form": entities}
    return parse_funsd_form(raw, form_id)


# --- parsing ---


def test_parse_form_entities_and_links():
    question = _ent(0, "question", [0, 0, 50, 20], text="Name", linking=[[0, 1]])
    answer = _ent(1, "answer", [60, 0, 120, 20], text="Bob", linking=[[0, 1]])
    parsed = _form([question, answer])
    assert parsed["form_id"] == "f0"
    assert len(parsed["entities"]) == 2
    head = parsed["entities"][0]
    assert head["label"] == "question"
    assert head["box"] == [0.0, 0.0, 50.0, 20.0]
    assert parsed["gold_links"] == {frozenset((0, 1))}


def test_parse_json_reads_file(tmp_path):
    payload = {"form": [
        _ent(0, "question", [0, 0, 1, 1], linking=[[0, 1]]),
        _ent(1, "answer", [2, 0, 3, 1], linking=[[0, 1]]),
    ]}
    target = tmp_path / "form.json"
    target.write_text(json.dumps(payload), encoding="utf-8")
    parsed = parse_funsd_json(target)
    assert parsed["form_id"] == "form"
    assert parsed["gold_links"] == {frozenset((0, 1))}


def test_links_dedupe_bidirectional_and_drop_invalid():
    # The link is recorded on both ends, next to a self-link and a dangling id;
    # only one clean pair should survive.
    question = _ent(0, "question", [0, 0, 50, 20], linking=[[0, 1], [0, 0]])
    answer = _ent(1, "answer", [60, 0, 120, 20], linking=[[1, 0], [1, 99]])
    parsed = _form([question, answer])
    assert parsed["gold_links"] == {frozenset((0, 1))}


# --- link scopes ---


def test_qa_link_filter_keeps_only_question_answer():
    parsed = _form([
        _ent(0, "question", [0, 0, 50, 20], linking=[[0, 1], [0, 2]]),
        _ent(1, "answer", [60, 0, 120, 20], linking=[[0, 1]]),
        _ent(2, "question", [0, 30, 50, 50], linking=[[0, 2]]),  # question-question
        _ent(3, "header", [0, 60, 50, 80], linking=[[3, 0]]),    # header-question
    ])
    expected_all = {frozenset((0, 1)), frozenset((0, 2)), frozenset((0, 3))}
    assert qa_gold_links(parsed) == {(0, 1)}  # directed question -> answer
    assert all_gold_links(parsed) == expected_all


def test_qa_link_canonicalizes_direction_regardless_of_record_order():
    # The link is stored answer-first; the qa scope must still emit
    # (question_id, answer_id).
    answer = _ent(0, "answer", [60, 0, 120, 20], linking=[[0, 1]])
    question = _ent(1, "question", [0, 0, 50, 20], linking=[[0, 1]])
    parsed = _form([answer, question])
    assert qa_gold_links(parsed) == {(1, 0)}


def test_header_question_link_excluded_from_qa_scope():
    header = _ent(0, "header", [0, 0, 50, 20], linking=[[0, 1]])
    question = _ent(1, "question", [0, 30, 50, 50], linking=[[0, 1]])
    parsed = _form([header, question])
    assert qa_gold_links(parsed) == set()
    assert all_gold_links(parsed) == {frozenset((0, 1))}


# --- metrics ---


def test_prf1_basic():
    predicted = {(1, 2), (3, 4)}
    gold = {(1, 2), (5, 6)}
    metrics = prf1(predicted, gold)
    assert metrics["tp"] == 1
    assert metrics["precision"] == 0.5
    assert metrics["recall"] == 0.5
    assert metrics["f1"] == 0.5


def test_prf1_empty_pred_is_zero():
    metrics = prf1(set(), {(1, 2)})
    assert metrics["precision"] == 0.0
    assert metrics["recall"] == 0.0
    assert metrics["f1"] == 0.0


def test_evaluate_pairs_micro_aggregation():
    per_form = [
        ({(1, 2)}, {(1, 2), (3, 4)}),  # tp 1, pred 1, gold 2
        (set(), {(5, 6)}),             # tp 0, pred 0, gold 1
        ({(7, 8)}, {(7, 8)}),          # tp 1, pred 1, gold 1
    ]
    metrics = evaluate_pairs(per_form)
    assert metrics["num_forms"] == 3
    assert metrics["tp"] == 2
    assert metrics["n_pred"] == 2
    assert metrics["n_gold"] == 4
    assert metrics["precision"] == 1.0  # 2 / 2
    assert metrics["recall"] == 0.5     # 2 / 4
    assert round(metrics["f1"], 4) == round(2 * 1.0 * 0.5 / 1.5, 4)


# --- heuristic predictor ---


def test_same_row_right_side_answer_is_linked():
    question = _ent(0, "question", [0, 0, 50, 20])
    answer = _ent(1, "answer", [60, 0, 120, 20])
    assert predict_qa_links(_form([question, answer])) == {(0, 1)}


def test_below_candidate_can_win_against_a_valid_right_candidate():
    # Answer A has a valid same-row question (far to its left) and a question
    # directly above it; the closer "below" relation must win the per-answer
    # argmax instead of serving only as a no-right-side fallback.
    parsed = _form([
        _ent(0, "question", [0, 100, 40, 120]),   # same row, far to A's left
        _ent(1, "question", [100, 70, 150, 90]),  # directly above A, close
        _ent(2, "answer", [100, 100, 150, 120]),
    ])
    assert predict_qa_links(parsed) == {(1, 2)}


def test_other_label_is_never_linked():
    parsed = _form([
        _ent(0, "question", [0, 0, 50, 20]),
        _ent(1, "other", [60, 0, 120, 20]),  # sits where an answer would, but is "other"
        _ent(2, "header", [0, 30, 50, 50]),
    ])
    assert predict_qa_links(parsed) == set()


def test_per_answer_argmax_links_only_the_nearer_question():
    parsed = _form([
        _ent(0, "question", [0, 0, 40, 20]),   # far left
        _ent(1, "question", [60, 0, 90, 20]),  # nearer to A
        _ent(2, "answer", [100, 0, 150, 20]),
    ])
    assert predict_qa_links(parsed) == {(1, 2)}


def test_distance_gate_drops_far_answer():
    # A right-side candidate exists and scores above min_score, but its normalized
    # distance exceeds max_distance_units, so the gate (not the score floor) drops it.
    question = _ent(0, "question", [0, 0, 50, 20])
    answer = _ent(1, "answer", [110, 0, 160, 20])  # hgap = (110-50)/20 = 3.0 median-heights
    parsed = _form([question, answer])
    tight = replace(HeuristicParams(), max_distance_units=2.0)
    assert predict_qa_links(parsed, tight) == set()
    assert predict_qa_links(parsed) == {(0, 1)}  # default gate (8.0) keeps it


def test_median_entity_height():
    parsed = _form([
        _ent(0, "question", [0, 0, 10, 10]),  # h 10
        _ent(1, "answer", [0, 0, 10, 30]),    # h 30
        _ent(2, "answer", [0, 0, 10, 20]),    # h 20
    ])
    assert median_entity_height(parsed["entities"]) == 20.0


# --- form-level evaluation (predictor + scope, end to end on a synthetic form) ---


def test_evaluate_forms_qa_scope_perfect_form():
    question = _ent(0, "question", [0, 0, 50, 20], linking=[[0, 1]])
    answer = _ent(1, "answer", [60, 0, 120, 20], linking=[[0, 1]])
    metrics = evaluate_forms([_form([question, answer])], "qa")
    assert metrics["scope"] == "qa"
    assert metrics["precision"] == 1.0
    assert metrics["recall"] == 1.0
    assert metrics["f1"] == 1.0


def test_evaluate_forms_all_scope_counts_uncovered_links():
    # One QA link (predictable) plus one header-question link that the QA-only
    # predictor cannot cover.
    parsed = _form([
        _ent(0, "question", [0, 0, 50, 20], linking=[[0, 1], [2, 0]]),
        _ent(1, "answer", [60, 0, 120, 20], linking=[[0, 1]]),
        _ent(2, "header", [0, 30, 50, 50], linking=[[2, 0]]),
    ])
    metrics = evaluate_forms([parsed], "all")
    assert metrics["scope"] == "all"
    assert metrics["n_gold"] == 2
    assert metrics["tp"] == 1  # only the QA link is recovered
    assert metrics["recall"] == 0.5