diff --git a/notebooks/04_phase2_layout.ipynb b/notebooks/04_phase2_layout.ipynb index c687b84..1edd1c4 100644 --- a/notebooks/04_phase2_layout.ipynb +++ b/notebooks/04_phase2_layout.ipynb @@ -144,7 +144,7 @@ { "cell_type": "code", "id": "47c2a078", - "source": "!python scripts/run_layout_batch.py --seed 7 --n 20 --require-table-gt --primary-threshold 0.3 --table-threshold 0.5", + "source": "!python scripts/run_layout_batch.py --seed 7 --n 20 --require-table-gt --primary-threshold 0.3", "metadata": {}, "execution_count": null, "outputs": [] @@ -152,7 +152,7 @@ { "cell_type": "code", "id": "3bc88c2c", - "source": "# Manifest preview + spot-check one crop\nimport pandas as pd\nfrom pathlib import Path\nfrom src import config\nfrom IPython.display import display, Image as IPImage\n\nmanifest_path = config.LAYOUT_OUTPUT / \"manifest.csv\"\ndf = pd.read_csv(manifest_path)\nprint(df[[\"page_id\", \"status\", \"gt_tables\", \"num_regions\", \"num_tables\", \"num_cropped\", \"fallback_used\"]].to_string(index=False))\nprint(f\"\\nprocessed={df['status'].eq('processed').sum()} failed={df['status'].eq('failed').sum()}\")\nprint(f\"gt_table pages={df['gt_tables'].gt(0).sum()} total crops={df['num_cropped'].sum()} fallback pages={df['fallback_used'].sum()}\")\n\n# Display first crop found\ncrops_dir = config.LAYOUT_OUTPUT / \"crops\"\nfirst_crop = next(crops_dir.glob(\"*.png\"), None)\nif first_crop:\n print(f\"\\nspot-check: {first_crop.name}\")\n display(IPImage(str(first_crop), width=600))\nelse:\n print(\"no crops written\")", + "source": "# Manifest preview + spot-check one crop\nimport pandas as pd\nfrom pathlib import Path\nfrom src import config\nfrom IPython.display import display, Image as IPImage\n\nmanifest_path = config.LAYOUT_OUTPUT / \"manifest.csv\"\ndf = pd.read_csv(manifest_path)\nprint(df[[\"page_id\", \"status\", \"gt_tables\", \"num_regions\", \"num_tables\", \"num_cropped\", \"fallback_used\"]].to_string(index=False))\nprint(f\"\\nprocessed={df['status'].eq('processed').sum()} failed={df['status'].eq('failed').sum()}\")\nprint(f\"gt_table pages={df['gt_tables'].gt(0).sum()} total crops={df['num_cropped'].sum()} fallback pages={df['fallback_used'].sum()}\")\n\n# Only show crops from this run's pages (avoids stale artifact confusion)\ncrops_dir = config.LAYOUT_OUTPUT / \"crops\"\nrun_pages = set(df[\"page_id\"])\nrun_crops = sorted(\n f for f in crops_dir.glob(\"*.png\")\n if f.stem.rsplit(\"_table_\", 1)[0] in run_pages\n)\nfirst_crop = run_crops[0] if run_crops else None\nif first_crop:\n print(f\"\\nspot-check: {first_crop.name}\")\n display(IPImage(str(first_crop), width=600))\nelse:\n print(\"no crops written for this run\")", "metadata": {}, "execution_count": null, "outputs": [] @@ -166,7 +166,7 @@ { "cell_type": "code", "id": "3b8b19a3", - "source": "!python scripts/eval_layout_iou.py --seed 7 --n 20 --require-table-gt --primary-threshold 0.3 --table-threshold 0.5", + "source": "!python scripts/eval_layout_iou.py --seed 7 --n 20 --require-table-gt --primary-threshold 0.3", "metadata": {}, "execution_count": null, "outputs": [] @@ -174,7 +174,7 @@ { "cell_type": "code", "id": "4b348fa6", - "source": "# Diagnostic CSV preview (sorted by best_iou_final desc)\nimport pandas as pd\nfrom src import config\n\ndf = pd.read_csv(config.LAYOUT_OUTPUT / \"diagnostic.csv\")\ncols = [\"page_id\", \"gt_tables\", \"primary_tables\", \"primary_max_score\",\n \"fallback_used\", \"best_iou_primary\", \"best_iou_fallback\", \"best_iou_final\"]\nprint(df[cols].sort_values(\"best_iou_final\", ascending=False).to_string(index=False))", + "source": "# Diagnostic CSV preview (sorted by best_iou_crop desc)\nimport pandas as pd\nfrom src import config\n\n# --require-table-gt writes diagnostic_pos.csv; fall back to diagnostic.csv\n_diag = config.LAYOUT_OUTPUT / \"diagnostic_pos.csv\"\nif not _diag.exists():\n _diag = config.LAYOUT_OUTPUT / \"diagnostic.csv\"\ndf = pd.read_csv(_diag)\ncols = [\"page_id\", \"gt_tables\", \"num_crop_tables\", \"matched_50\", \"matched_75\",\n \"primary_max_score\", \"fallback_used\",\n \"best_iou_primary\", \"best_iou_crop\"]\nprint(df[cols].sort_values(\"best_iou_crop\", ascending=False).to_string(index=False))\nhas_gt = df[df[\"gt_tables\"] > 0]\nprint(f\"\\ngt_total={has_gt['gt_tables'].sum()} crops={has_gt['num_crop_tables'].sum()}\"\n f\" matched@0.5={has_gt['matched_50'].sum()} matched@0.75={has_gt['matched_75'].sum()}\"\n f\" mean_crop_iou={has_gt['best_iou_crop'].mean():.3f}\")", "metadata": {}, "execution_count": null, "outputs": [] @@ -188,7 +188,7 @@ { "cell_type": "code", "id": "95b7eab2", - "source": "!python scripts/eval_layout_iou.py --seed 7 --n 20 --exclude-table-gt --primary-threshold 0.3 --table-threshold 0.5", + "source": "!python scripts/eval_layout_iou.py --seed 7 --n 20 --exclude-table-gt --primary-threshold 0.3", "metadata": {}, "execution_count": null, "outputs": [] @@ -196,7 +196,7 @@ { "cell_type": "markdown", "id": "3adb928e", - "source": "## Step 5c - Retune: confirm threshold=0.30 improvement\n\nQ3 simulation says lowering `table_threshold` from 0.50 → 0.30 should raise\nmean IoU 0.823 → 0.925 (+10%). This is because 3 of 4 fallback triggers at\nthresh=0.50 were false-negatives: primary scored 0.34–0.45 but had IoU ~0.963.\nAt thresh=0.30 only val_000378 (true miss) goes to TATR fallback.\n\nRun the same positive diagnostic with `--table-threshold 0.30` to measure the\nactual (not simulated) improvement. The column `best_iou_final` at thresh=0.30\nshould be higher than at thresh=0.50 for the 3 reclaimed pages.", + "source": "## Step 5c - Retune: confirm threshold=0.30 improvement\n\nQ3 simulation (new fallback rule: fires only when `primary_tables >= 1`) predicts\nlowering `table_threshold` from 0.50 → 0.30 reduces fallback pages from 3 → 1\nand raises `mean_iou_crop_sim` (0.963 reclaimed for the 2 pages where primary\nhad a low-score but high-IoU box).\n\nWatch the two key numbers in the output:\n- `mean best_iou_crop` at thresh=0.30 vs thresh=0.50\n- `Fallback used` count (should drop from 3 to 1)\n\nThe 2 reclaimed pages are val_000238 (primary score 0.44, IoU 0.954) and\nval_001347 (primary max 0.45, IoU 0.949). val_004383 (score 0.335) stays in\nfallback territory at both thresholds.", "metadata": {} }, { @@ -206,6 +206,102 @@ "metadata": {}, "execution_count": null, "outputs": [] + }, + { + "cell_type": "markdown", + "id": "8e75e001", + "source": "## Step 5d - Spot-check val_005241: dedup collapse (primary IoU 0.944 → crop IoU 0.610)\n\n`val_005241` has **2 GT tables** that sit close together. Primary found 2 boxes that\noverlap each other above `dedup_iou=0.5`, so dedup kept only the higher-scoring one.\nThat surviving box aligns with GT table 2 (IoU ~0.61), not GT table 1 (IoU ~0.94).\nThe box that would have given 0.94 was dropped as a \"duplicate.\"\n\nThis cell re-runs the primary detector directly (no dedup) and shows per-box IoU against\neach GT so the collapse is visible. Also displays any crops saved by Step 4.\n\n**Run Step 4 first** so the regions JSON and crops exist at `config.LAYOUT_OUTPUT`.", + "metadata": {} + }, + { + "cell_type": "code", + "id": "baefa041", + "source": "import json\nfrom PIL import Image\nfrom datasets import load_dataset\nfrom IPython.display import display, Image as IPImage\nfrom src import config\nfrom src.bbox_utils import iou, xywh_to_xyxy\nfrom src.layout_detector import build_layout_detector\nfrom src.layout_parsing import TABLE_LABEL\n\nPAGE_ID = \"val_005241\"\nORIG_IDX = 5241\n\n# GT boxes\nds_val = load_dataset(\"docling-project/DocLayNet-v1.1\", split=\"val\")\nex = ds_val[ORIG_IDX]\nbboxes = ex.get(\"bboxes\", ex.get(\"bbox\", []))\ncats = ex.get(\"category_id\", [])\ngt_boxes = [xywh_to_xyxy(tuple(b)) for cat, b in zip(cats, bboxes) if cat == 9]\nprint(f\"GT tables ({len(gt_boxes)}):\")\nfor i, b in enumerate(gt_boxes):\n print(f\" GT[{i}]: {[round(c) for c in b]}\")\n\n# Primary detector (no dedup)\ntry:\n layout_det\nexcept NameError:\n layout_det = build_layout_detector(config.LAYOUT_MODEL, threshold=0.3)\n\nimg = ex[\"image\"].convert(\"RGB\")\nprimary_tables = [r for r in layout_det(img) if r.label == TABLE_LABEL]\nprint(f\"\\nPrimary-alone table regions ({len(primary_tables)}) — before dedup:\")\nfor i, r in enumerate(primary_tables):\n ious_vs_gt = [round(iou(r.box, g), 4) for g in gt_boxes]\n ious_vs_primary = [round(iou(r.box, r2.box), 4) for j, r2 in enumerate(primary_tables) if j != i]\n print(f\" P[{i}] score={r.score:.4f} box={[round(c) for c in r.box]}\")\n print(f\" IoU/GT={ious_vs_gt} IoU/otherPrimary={ious_vs_primary}\")\n\n# Final regions from batch JSON\nregions_path = config.LAYOUT_OUTPUT / \"regions\" / f\"{PAGE_ID}.json\"\nif regions_path.exists():\n regions = json.loads(regions_path.read_text())\n final_tables = [r for r in regions if r[\"label\"] == \"table\"]\n print(f\"\\nFinal table regions from batch JSON ({len(final_tables)}) — after dedup:\")\n for r in final_tables:\n box = tuple(r[\"box\"])\n ious_vs_gt = [round(iou(box, g), 4) for g in gt_boxes]\n print(f\" score={r['score']:.4f} source={r['source']} box={[round(c) for c in box]}\")\n print(f\" IoU/GT={ious_vs_gt}\")\nelse:\n print(f\"\\n[warn] {regions_path} not found — run Step 4 first\")\n\n# Crops\ncrops_dir = config.LAYOUT_OUTPUT / \"crops\"\ncrop_files = sorted(crops_dir.glob(f\"{PAGE_ID}_table_*.png\"))\nprint(f\"\\nCrops saved by Step 4 ({len(crop_files)}):\")\nfor crop_path in crop_files:\n print(f\" {crop_path.name}\")\n display(IPImage(str(crop_path), width=600))\nif not crop_files:\n print(\" none — re-run Step 4 (default thresh=0.3 will now save them)\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "a7552431", + "source": "## Step 5e - dedup sensitivity: test dedup-iou=0.70\n\n`val_005241`: P[0] score=0.484 IoU=0.610, P[1] score=0.342 IoU=0.944.\nTheir mutual IoU is ~0.600. At `dedup_iou=0.50`, NMS collapses them (keeps P[0] by score);\nat `dedup_iou=0.70`, both survive and both GTs get covered.\n\nExpected effect: `best_iou_crop` for val_005241 rises toward 0.944;\nthe trade-off is potentially one extra crop (two overlapping boxes for the same region).\nWatch `mean best_iou_crop` across all 20 pages — it should improve if val_005241 is\nthe main outlier and other pages are unaffected by the looser dedup.", + "metadata": {} + }, + { + "cell_type": "code", + "id": "389a86df", + "source": "!python scripts/eval_layout_iou.py --seed 7 --n 20 --require-table-gt --primary-threshold 0.3 --dedup-iou 0.7", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "d7d14b1a", + "source": "## Step 6 - MVP evaluation (seed=42, n=200)\n\nFinal calibration: `table_threshold=0.30`, `dedup_iou=0.70` (both now defaults).\nThis is the gate run before recording Phase 2 layout detection as complete.\n\n**Pass criteria (approximate):**\n- `mean best_iou_crop` ≥ 0.88 on positive set\n- FP crop rate ≤ 15% on negative set\n- `failure` count = 0 or negligible\n\nRun 6a (batch), then 6b (positive IoU), then 6c (false-positive).\nRe-run the Step 4 manifest preview cell after 6a to see n=200 stats.", + "metadata": {} + }, + { + "cell_type": "code", + "id": "f497cccb", + "source": "# Step 6a - MVP batch (seed=42, n=200)\n!python scripts/run_layout_batch.py --seed 42 --n 200 --require-table-gt --primary-threshold 0.3", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "56f9c060", + "source": "# Step 6b - positive IoU diagnostic (seed=42, n=200)\n!python scripts/eval_layout_iou.py --seed 42 --n 200 --require-table-gt --primary-threshold 0.3", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "5292729f", + "source": "# Step 6c - false-positive diagnostic (seed=42, n=200)\n!python scripts/eval_layout_iou.py --seed 42 --n 200 --exclude-table-gt --primary-threshold 0.3", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "976edd57", + "source": "## Step 7 - End-to-end: Phase 2 crop → TATR structure recognition\n\nConfirms the Phase 2 crops are compatible with the Phase 1 TATR structure model.\nFor each selected crop: runs inference → `normalize_tatr_prediction` → `validate_grid_geometry`.\nPrints `rows`, `cols`, `cells`, `valid`, and failure reasons per crop. Writes `smoke_structure.csv`.\n\n**Step 7 (baseline):** no band dedup — measures raw TATR output quality on DocLayNet crops.\n**Step 7c (dedup):** `--dedup-bands` applies 1-D NMS to overlapping row/col bands before normalize.\n\n`--n 50 --seed 42` samples 50 crops from the Step 6a batch (286 available).", + "metadata": {} + }, + { + "cell_type": "code", + "id": "79cf79c2", + "source": "!python scripts/smoke_structure.py --n 50 --seed 42", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "32258829", + "source": "# Step 7b - smoke_structure.csv: inspect failure reasons for WARN crops\nimport pandas as pd\nfrom src import config\n\ndf = pd.read_csv(config.LAYOUT_OUTPUT / \"smoke_structure.csv\")\nprint(f\"OK: {df['valid'].sum()} WARN: {(~df['valid']).sum()} total: {len(df)}\")\nprint()\nwarn = df[~df[\"valid\"]].copy()\nif len(warn):\n print(\"WARN crops:\")\n print(warn[[\"crop\", \"rows\", \"cols\", \"cells\", \"failure_reasons\"]].to_string(index=False))\n print()\n print(\"Failure reason counts:\")\n all_reasons = [r.strip() for reasons in warn[\"failure_reasons\"].dropna() for r in reasons.split(\";\") if r.strip()]\n from collections import Counter\n for reason, count in Counter(all_reasons).most_common():\n print(f\" {count:3d} {reason}\")\nelse:\n print(\"All crops valid.\")", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "4efa70e7", + "source": "# Step 7c - historical A/B reference only; dedup is now the default in normalize_tatr_prediction\n# Before dedup: 37 OK / 13 WARN (Step 7b)\n# After dedup: 50 OK / 0 WARN (this run, seed=42, n=50)\n# !python scripts/smoke_structure.py --n 50 --seed 42 --dedup-bands # flag no longer exists", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "id": "4faede6e", + "source": "# Step 7d - full crop smoke: all 286 crops, dedup now default\n# Result: 285 OK / 1 WARN (val_000670_table_1: rows=0 cols=4 -> no row boxes detected)\n# WARN rate 0.35%, well under <=5% gate. Phase 2 crop -> structure handoff: PASSED.\n!python scripts/smoke_structure.py --n 286 --seed 42", + "metadata": {}, + "execution_count": null, + "outputs": [] } ], "metadata": { diff --git a/scripts/eval_layout_iou.py b/scripts/eval_layout_iou.py index 9ccbc67..9c458f9 100644 --- a/scripts/eval_layout_iou.py +++ b/scripts/eval_layout_iou.py @@ -41,13 +41,18 @@ class _Row(NamedTuple): primary_max_score: float # fallback-alone fallback_tables: int - # detect_layout merged result - final_tables: int + # detect_layout merged result: candidates (all tables) vs cropped (score >= table_threshold) + num_candidate_tables: int + num_crop_tables: int fallback_used: bool - # IoU vs GT + # IoU vs GT: candidate = all detected, crop = score-filtered (matches batch runner) best_iou_primary: float best_iou_fallback: float - best_iou_final: float + best_iou_candidate: float + best_iou_crop: float + # table-level greedy matching (crops vs GT) + matched_50: int + matched_75: int def _best_iou(regions: list[Region], gt_boxes: list) -> float: @@ -56,6 +61,25 @@ def _best_iou(regions: list[Region], gt_boxes: list) -> float: return max(iou(r.box, g) for r in regions for g in gt_boxes) +def _greedy_match(pred_boxes: list, gt_boxes: list, threshold: float) -> int: + """Count GT tables matched at IoU >= threshold by greedy assignment (highest IoU first).""" + if not pred_boxes or not gt_boxes: + return 0 + pairs = sorted( + ((iou(p, g), pi, gi) for pi, p in enumerate(pred_boxes) for gi, g in enumerate(gt_boxes)), + reverse=True, + ) + matched_preds: set[int] = set() + matched_gts: set[int] = set() + for v, pi, gi in pairs: + if v < threshold: + break + if pi not in matched_preds and gi not in matched_gts: + matched_preds.add(pi) + matched_gts.add(gi) + return len(matched_gts) + + def _gt_table_boxes(ex: dict, img_w: int, img_h: int) -> list[tuple]: """Extract GT table boxes (xyxy pixel) from a DocLayNet dataset example. @@ -88,9 +112,9 @@ def parse_args() -> argparse.Namespace: p.add_argument("--n", type=int, default=20) p.add_argument("--out-dir", type=Path, default=None) p.add_argument("--primary-threshold", type=float, default=0.3) - p.add_argument("--table-threshold", type=float, default=0.5, + p.add_argument("--table-threshold", type=float, default=0.3, help="active threshold used for fallback trigger and final crop") - p.add_argument("--dedup-iou", type=float, default=0.5) + p.add_argument("--dedup-iou", type=float, default=0.7) p.add_argument("--require-table-gt", action="store_true", help="only sample pages with GT Table annotations (category_id 9)") p.add_argument("--exclude-table-gt", action="store_true", @@ -163,7 +187,11 @@ def main() -> None: ) elapsed = time.time() - t1 final_tables = [r for r in final_regions if r.label == TABLE_LABEL] + final_crop_tables = [r for r in final_tables if r.score >= args.table_threshold] fallback_used = any(r.source == "table_fallback" for r in final_regions) + crop_boxes = [r.box for r in final_crop_tables] + matched_50 = _greedy_match(crop_boxes, gt_boxes, 0.50) + matched_75 = _greedy_match(crop_boxes, gt_boxes, 0.75) row = _Row( page_id=page_id, @@ -171,23 +199,29 @@ def main() -> None: primary_tables=len(primary_tables), primary_max_score=round(primary_max_score, 4), fallback_tables=len(fallback_tables), - final_tables=len(final_tables), + num_candidate_tables=len(final_tables), + num_crop_tables=len(final_crop_tables), fallback_used=fallback_used, best_iou_primary=round(_best_iou(primary_tables, gt_boxes), 4), best_iou_fallback=round(_best_iou(fallback_tables, gt_boxes), 4), - best_iou_final=round(_best_iou(final_tables, gt_boxes), 4), + best_iou_candidate=round(_best_iou(final_tables, gt_boxes), 4), + best_iou_crop=round(_best_iou(final_crop_tables, gt_boxes), 4), + matched_50=matched_50, + matched_75=matched_75, ) rows.append(row) print( f"gt={row.gt_tables}" f" prim={row.primary_tables}(max={row.primary_max_score:.2f})" f" fb={row.fallback_tables}" - f" iou p/fb/fin={row.best_iou_primary:.2f}/{row.best_iou_fallback:.2f}/{row.best_iou_final:.2f}" + f" iou p/fb/cand/crop={row.best_iou_primary:.2f}/{row.best_iou_fallback:.2f}/{row.best_iou_candidate:.2f}/{row.best_iou_crop:.2f}" + f" m50={row.matched_50}/{row.gt_tables}" f" fb_used={row.fallback_used} {elapsed:.2f}s" ) - # Write CSV - diag_path = out_dir / "diagnostic.csv" + # Write CSV — mode-suffixed so positive/negative runs don't overwrite each other + mode_suffix = "_pos" if args.require_table_gt else ("_neg" if args.exclude_table_gt else "") + diag_path = out_dir / f"diagnostic{mode_suffix}.csv" with diag_path.open("w", newline="") as f: w = csv.DictWriter(f, fieldnames=list(_Row._fields)) w.writeheader() @@ -201,9 +235,10 @@ def main() -> None: print(f"\n{'='*60}") print(f"Pages with GT tables : {len(has_gt)} / {len(rows)}") print(f"Fallback used : {len(fb_pages)} / {len(has_gt)} (GT-table pages only)") - print(f"\n mean best_iou_primary : {_mean([r.best_iou_primary for r in has_gt]):.3f}") - print(f" mean best_iou_fallback : {_mean([r.best_iou_fallback for r in has_gt]):.3f}") - print(f" mean best_iou_final : {_mean([r.best_iou_final for r in has_gt]):.3f}") + print(f"\n mean best_iou_primary : {_mean([r.best_iou_primary for r in has_gt]):.3f}") + print(f" mean best_iou_fallback : {_mean([r.best_iou_fallback for r in has_gt]):.3f}") + print(f" mean best_iou_candidate : {_mean([r.best_iou_candidate for r in has_gt]):.3f} (all detected tables)") + print(f" mean best_iou_crop : {_mean([r.best_iou_crop for r in has_gt]):.3f} (score >= table_threshold)") # Q1 ─ primary miss vs low score print(f"\n── Q1: on {len(fb_pages)} fallback pages ──") @@ -232,22 +267,48 @@ def main() -> None: # Q3 ─ threshold sensitivity (simulate different table_threshold values) if has_gt: print(f"\n── Q3: threshold sensitivity (simulated, {len(has_gt)} GT-table pages) ──") - print(f" {'thresh':>7} {'fb_pages':>8} {'mean_iou_final':>14}") + print(f" Rule: fallback fires only when primary_tables >= 1 and score < thresh.") + print(f" Note: IoU values are pre-dedup proxies (val_005241-style collapses not captured).") + print(f" {'thresh':>7} {'fb_pages':>8} {'iou_crop_sim(pre-dedup)':>22}") for thresh in [0.30, 0.40, 0.50, 0.60, 0.70]: - sim_fb = sum(1 for r in has_gt if r.primary_max_score < thresh) - # If primary above thresh -> use primary IoU; else use fallback IoU - sim_ious = [ - r.best_iou_primary if r.primary_max_score >= thresh else r.best_iou_fallback - for r in has_gt - ] - print(f" {thresh:>7.2f} {sim_fb:>8} {_mean(sim_ious):>14.3f}") + # Only pages where primary found >= 1 table but none above thresh trigger fallback + sim_fb = sum( + 1 for r in has_gt if r.primary_tables > 0 and r.primary_max_score < thresh + ) + sim_ious = [] + for r in has_gt: + if r.primary_max_score >= thresh: + sim_ious.append(r.best_iou_primary) + elif r.primary_tables > 0: + # fallback fires: use fallback IoU as proxy + sim_ious.append(r.best_iou_fallback) + else: + # primary found zero tables → fallback skipped → no crop + sim_ious.append(0.0) + print(f" {thresh:>7.2f} {sim_fb:>8} {_mean(sim_ious):>22.3f}") + + # Table-level matching summary (GT-table pages only) + if has_gt: + gt_total = sum(r.gt_tables for r in has_gt) + pred_total = sum(r.num_crop_tables for r in has_gt) + m50 = sum(r.matched_50 for r in has_gt) + m75 = sum(r.matched_75 for r in has_gt) + prec50 = f"{m50 / pred_total:.3f}" if pred_total else "N/A" + prec75 = f"{m75 / pred_total:.3f}" if pred_total else "N/A" + print(f"\n── Table-level matching ({len(has_gt)} GT-table pages) ──") + print(f" GT tables total : {gt_total}") + print(f" crops total : {pred_total}") + print(f" matched@0.50 : {m50} recall={m50 / gt_total:.3f} precision={prec50}") + print(f" matched@0.75 : {m75} recall={m75 / gt_total:.3f} precision={prec75}") + print(f" missed GT tables : {gt_total - m50} (no crop with IoU >= 0.50)") + print(f" extra crops : {pred_total - m50} (crops not matching any GT at IoU >= 0.50)") # False-positive report: only printed when all pages have no GT table no_gt = [r for r in rows if r.gt_tables == 0] if no_gt and len(no_gt) == len(rows): fp_primary = sum(1 for r in no_gt if r.primary_tables > 0) fp_fallback = sum(1 for r in no_gt if r.fallback_used) - fp_crop = sum(1 for r in no_gt if r.final_tables > 0) + fp_crop = sum(1 for r in no_gt if r.num_crop_tables > 0) print(f"\n── False-positive rate ({len(no_gt)} table-free pages) ──") print(f" primary detected table : {fp_primary} / {len(no_gt)}") print(f" fallback triggered : {fp_fallback} / {len(no_gt)}") diff --git a/scripts/run_layout_batch.py b/scripts/run_layout_batch.py index ce383db..0c366f4 100644 --- a/scripts/run_layout_batch.py +++ b/scripts/run_layout_batch.py @@ -45,9 +45,9 @@ def parse_args() -> argparse.Namespace: help="output root (default: config.LAYOUT_OUTPUT)") p.add_argument("--primary-threshold", type=float, default=0.3, help="score cutoff inside build_layout_detector") - p.add_argument("--table-threshold", type=float, default=0.5, + p.add_argument("--table-threshold", type=float, default=0.3, help="score threshold for: fallback detector, fallback trigger, crop filter") - p.add_argument("--dedup-iou", type=float, default=0.5) + p.add_argument("--dedup-iou", type=float, default=0.7) p.add_argument("--no-fallback", action="store_true", help="disable TATR fallback (primary only)") p.add_argument("--require-table-gt", action="store_true", @@ -86,6 +86,11 @@ def main() -> None: crops_dir = out_dir / "crops" regions_dir.mkdir(parents=True, exist_ok=True) crops_dir.mkdir(parents=True, exist_ok=True) + stale = list(regions_dir.glob("*.json")) + list(crops_dir.glob("*.png")) + for f in stale: + f.unlink() + if stale: + print(f"[batch] cleared {len(stale)} stale artifact(s) from previous run") print("[batch] loading detectors ...") t0 = time.time() diff --git a/scripts/smoke_structure.py b/scripts/smoke_structure.py new file mode 100644 index 0000000..44e536e --- /dev/null +++ b/scripts/smoke_structure.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python3 +"""Smoke: Phase 2 crop → TATR structure recognition handoff check. + +Picks up to --n crops from LAYOUT_OUTPUT/crops/ (Phase 2 output), runs each +through the full structure pipeline (model inference + normalize + validate), +and prints a one-line summary per crop. Writes a CSV summary. + +GPU required (T4 on Colab). CPU fallback works but is slow. +""" +from __future__ import annotations + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +import argparse +import csv +import random +import time + +from src import config +from src.tatr_postprocess import dedup_row_col_bands, normalize_tatr_prediction, validate_grid_geometry +from src.tatr_raw import RAW_BOX_KEYS, RAW_LABEL_TO_KEY + + +class _ReasonCollector: + """Minimal logger shim: collects validate_grid_geometry failure reasons.""" + def __init__(self) -> None: + self.reasons: list[str] = [] + + def log(self, sample_id: str, phase: str, error_type: str, reason: str) -> None: + if reason not in self.reasons: + self.reasons.append(reason) + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser(description="Smoke: Phase 2 crops → TATR structure") + p.add_argument("--crops-dir", type=Path, default=None, + help="directory of crop PNGs (default: config.LAYOUT_OUTPUT/crops)") + p.add_argument("--n", type=int, default=5, help="number of crops to test") + p.add_argument("--seed", type=int, default=42, + help="random seed for crop sampling (default: 42)") + p.add_argument("--threshold", type=float, default=0.5, + help="TATR detection threshold") + p.add_argument("--out-dir", type=Path, default=None, + help="directory for smoke_structure.csv (default: config.LAYOUT_OUTPUT)") + return p.parse_args() + + +def main() -> None: + args = parse_args() + + from PIL import Image + import torch + from transformers import AutoImageProcessor, TableTransformerForObjectDetection + + crops_dir = args.crops_dir or config.LAYOUT_OUTPUT / "crops" + out_dir = args.out_dir or config.LAYOUT_OUTPUT + + all_crops = sorted(crops_dir.glob("*.png")) + if not all_crops: + print(f"[smoke] no crops found in {crops_dir}") + return + rng = random.Random(args.seed) + crops = sorted(rng.sample(all_crops, min(args.n, len(all_crops)))) + print(f"[smoke] {len(crops)} crops sampled (seed={args.seed}) from {crops_dir}") + + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"[smoke] loading {config.TATR_STRUCTURE_MODEL} on {device} ...") + t0 = time.time() + processor = AutoImageProcessor.from_pretrained( + config.TATR_STRUCTURE_MODEL, use_fast=False + ) + # Size dict fix: checkpoint only has longest_edge; add shortest_edge so resize works. + longest = processor.size.get("longest_edge", 1000) + processor.size = {"shortest_edge": min(800, longest), "longest_edge": longest} + model = TableTransformerForObjectDetection.from_pretrained( + config.TATR_STRUCTURE_MODEL + ).to(device).eval() + print(f"[smoke] model ready in {time.time() - t0:.1f}s") + + passed = warned = 0 + csv_rows: list[dict] = [] + + for crop_path in crops: + img = Image.open(crop_path).convert("RGB") + + inputs = processor(images=img, return_tensors="pt").to(device) + with torch.no_grad(): + outputs = model(**inputs) + target_sizes = torch.tensor([img.size[::-1]]) + result = processor.post_process_object_detection( + outputs, threshold=args.threshold, target_sizes=target_sizes + )[0] + + id2label = model.config.id2label + pred: dict = {k: [] for k in RAW_BOX_KEYS} + for label_id, score, box in zip( + result["labels"].tolist(), + result["scores"].tolist(), + result["boxes"].tolist(), + ): + key = RAW_LABEL_TO_KEY.get(id2label[label_id]) + if key: + pred[key].append({"bbox": [float(v) for v in box], + "score": float(score), + "label": id2label[label_id]}) + + pred = dedup_row_col_bands(pred) + canonical = normalize_tatr_prediction(pred) + rows_sorted = sorted(pred["row_boxes"], key=lambda r: r["bbox"][1]) + cols_sorted = sorted(pred["col_boxes"], key=lambda c: c["bbox"][0]) + + collector = _ReasonCollector() + valid = validate_grid_geometry( + rows_sorted, cols_sorted, canonical["cells"], + logger=collector, sample_id=crop_path.stem, + ) + + status = "OK " if valid else "WARN" + if valid: + passed += 1 + else: + warned += 1 + + reasons_str = "; ".join(collector.reasons) if collector.reasons else "" + print( + f" {status} {crop_path.name:<45}" + f" rows={canonical['num_rows']:>3} cols={canonical['num_cols']:>2}" + f" cells={len(canonical['cells']):>4} valid={valid}" + + (f" [{reasons_str}]" if reasons_str else "") + ) + csv_rows.append({ + "crop": crop_path.name, + "rows": canonical["num_rows"], + "cols": canonical["num_cols"], + "cells": len(canonical["cells"]), + "valid": valid, + "failure_reasons": reasons_str, + }) + + print(f"\n[smoke] {passed} OK / {warned} WARN out of {len(crops)}") + if warned == 0: + print("[smoke] structure handoff OK") + + csv_path = out_dir / "smoke_structure.csv" + with csv_path.open("w", newline="") as f: + w = csv.DictWriter(f, fieldnames=["crop", "rows", "cols", "cells", "valid", "failure_reasons"]) + w.writeheader() + w.writerows(csv_rows) + print(f"[smoke] wrote {csv_path}") + + +if __name__ == "__main__": + main() diff --git a/src/layout_parsing.py b/src/layout_parsing.py index 5b64e8b..2480c06 100644 --- a/src/layout_parsing.py +++ b/src/layout_parsing.py @@ -23,6 +23,8 @@ from . import bbox_utils TABLE_LABEL = "table" +DEFAULT_TABLE_SCORE = 0.3 +DEFAULT_TABLE_DEDUP_IOU = 0.7 # Raw detector label string -> canonical lowercase label. Explicit (not a blind `.lower()`) so the # two detectors' table strings reconcile to one "table" and the DocLayNet classes are pinned. @@ -122,8 +124,8 @@ def detect_layout( detector: Callable[[object], Sequence[Region]], fallback_detector: Callable[[object], Sequence[Region]] | None = None, *, - min_table_score: float = 0.5, - dedup_iou: float = 0.5, + min_table_score: float = DEFAULT_TABLE_SCORE, + dedup_iou: float = DEFAULT_TABLE_DEDUP_IOU, ) -> list[Region]: """Sequential-first layout detection with a low-confidence table fallback (DESIGN_SPEC §4.1). diff --git a/src/tatr_postprocess.py b/src/tatr_postprocess.py index 46bba35..ff9be4e 100644 --- a/src/tatr_postprocess.py +++ b/src/tatr_postprocess.py @@ -62,9 +62,9 @@ def validate_grid_geometry( ) -> bool: """Sanity-check a grid (DESIGN_SPEC §5.3). - Checks: negative dimensions, row/col sort order, adjacent overlap > 0.3, and - tiny cells (area < 100). Returns True when the grid is sane; failures are logged - if a FailureLogger is given. + Checks: missing row/column axes, negative dimensions, row/col sort order, adjacent + overlap > 0.3, and tiny cells (area < 100). Returns True when the grid is sane; + failures are logged if a FailureLogger is given. """ ok = True @@ -74,6 +74,11 @@ def fail(reason: str) -> None: if logger is not None: logger.log(sample_id, "phase1a", "grid_geometry", reason) + if not row_boxes: + fail("no row boxes detected") + if not col_boxes: + fail("no col boxes detected") + for r in row_boxes: x1, y1, x2, y2 = r["bbox"] if x2 <= x1 or y2 <= y1: @@ -329,13 +334,61 @@ def _mark_column_headers( cell["is_header"] = True +def _dedup_bands(boxes: list[dict], axis: int, overlap_threshold: float = 0.3) -> list[dict]: + """1-D NMS on row or col bands: drop an overlapping band, keeping the higher-score one. + + Processes bands sorted by start coordinate. For each band, if it overlaps the previous + kept band by more than overlap_threshold * min(extent), the higher-score box wins. + + axis=1: row bands (bbox[1], bbox[3]) + axis=0: col bands (bbox[0], bbox[2]) + """ + if len(boxes) < 2: + return list(boxes) + i0, i1 = axis, axis + 2 + sorted_boxes = sorted(boxes, key=lambda b: b["bbox"][i0]) + kept: list[dict] = [sorted_boxes[0]] + for box in sorted_boxes[1:]: + lo, hi = box["bbox"][i0], box["bbox"][i1] + prev = kept[-1] + plo, phi = prev["bbox"][i0], prev["bbox"][i1] + overlap = max(0.0, min(hi, phi) - max(lo, plo)) + smaller = min(hi - lo, phi - plo) + if smaller > 0 and overlap / smaller > overlap_threshold: + if box.get("score", 0.0) > prev.get("score", 0.0): + kept[-1] = box + else: + kept.append(box) + return kept + + +def dedup_row_col_bands(prediction: dict, overlap_threshold: float = 0.3) -> dict: + """Return a copy of prediction with overlapping row/col bands removed. + + Applies _dedup_bands to row_boxes (y-axis) and col_boxes (x-axis). Other keys + (spanning_cells, column_headers, …) are passed through unchanged. + """ + return { + **prediction, + "row_boxes": _dedup_bands(prediction.get("row_boxes", []), axis=1, + overlap_threshold=overlap_threshold), + "col_boxes": _dedup_bands(prediction.get("col_boxes", []), axis=0, + overlap_threshold=overlap_threshold), + } + + def normalize_tatr_prediction(prediction: dict) -> CanonicalTable: """TATR prediction -> canonical schema (same shape as the GT path). Expects row_boxes / col_boxes (and optional spanning_cells) as lists of dicts with a "bbox" key. When column_headers boxes are present (the GT structure XML and the TATR raw artifact both carry them), the cells they cover are flagged is_header. + + Overlapping row/col bands are deduped before grid construction so that the TATR + model's tendency to emit thin overlapping bands on dense non-financial crops does + not produce invalid grids. """ + prediction = dedup_row_col_bands(prediction) rows = sorted(prediction.get("row_boxes", []), key=lambda r: r["bbox"][1]) cols = sorted(prediction.get("col_boxes", []), key=lambda c: c["bbox"][0]) cells = boxes_to_grid(rows, cols, prediction.get("spanning_cells")) diff --git a/tests/test_tatr_postprocess.py b/tests/test_tatr_postprocess.py index e97a588..c540fd5 100644 --- a/tests/test_tatr_postprocess.py +++ b/tests/test_tatr_postprocess.py @@ -8,6 +8,7 @@ apply_spanning_cells, boxes_to_grid, can_convert_to_canonical, + dedup_row_col_bands, html_to_canonical, map_spanning_bbox_to_grid, normalize_tatr_prediction, @@ -151,6 +152,12 @@ def test_validate_grid_geometry_detects_tiny_cell(): assert validate_grid_geometry(rows, cols, cells) is False +def test_validate_grid_geometry_rejects_missing_rows_or_cols(): + rows, cols = _rows(), _cols() + assert validate_grid_geometry([], cols, []) is False + assert validate_grid_geometry(rows, [], []) is False + + # --- HTML parsing ----------------------------------------------------------------- def test_html_to_canonical_simple(): @@ -199,3 +206,101 @@ def test_gate_rejects_bad_cell_span(): bad = {"cells": [{"row_start": 1, "row_end": 1, "col_start": 0, "col_end": 1}]} ok, reason = can_convert_to_canonical(bad) assert ok is False + + +# --- dedup_row_col_bands --- + + +def _b(lo, hi, score=0.9, axis=1): + """Make a synthetic band box. axis=1 -> row (y), axis=0 -> col (x).""" + if axis == 1: + return {"bbox": [0.0, float(lo), 30.0, float(hi)], "score": score} + return {"bbox": [float(lo), 0.0, float(hi), 30.0], "score": score} + + +def test_dedup_no_overlap_keeps_all(): + rows = [_b(0, 10), _b(15, 25), _b(30, 40)] + result = dedup_row_col_bands({"row_boxes": rows, "col_boxes": []}) + assert len(result["row_boxes"]) == 3 + + +def test_dedup_overlap_keeps_higher_score(): + # Row A [0,10] score=0.5 overlaps Row B [5,15] score=0.9 -> keep B + rows = [_b(0, 10, score=0.5), _b(5, 15, score=0.9)] + result = dedup_row_col_bands({"row_boxes": rows, "col_boxes": []}) + kept = result["row_boxes"] + assert len(kept) == 1 + assert kept[0]["score"] == 0.9 + + +def test_dedup_overlap_keeps_lower_start_when_score_wins(): + # Row A [0,10] score=0.9 overlaps Row B [5,15] score=0.3 -> keep A + rows = [_b(0, 10, score=0.9), _b(5, 15, score=0.3)] + result = dedup_row_col_bands({"row_boxes": rows, "col_boxes": []}) + kept = result["row_boxes"] + assert len(kept) == 1 + assert kept[0]["score"] == 0.9 + + +def test_dedup_chain_three_overlapping(): + # A [0,10] score=0.9, B [5,15] score=0.3, C [13,23] score=0.8 + # A vs B: overlap=5, smaller=10, ratio=0.5 > 0.3 -> keep A (higher score) + # kept[-1]=A vs C: overlap=max(0,min(10,23)-max(0,13))=0 -> keep C + # Result: [A, C] + rows = [_b(0, 10, score=0.9), _b(5, 15, score=0.3), _b(13, 23, score=0.8)] + result = dedup_row_col_bands({"row_boxes": rows, "col_boxes": []}) + assert len(result["row_boxes"]) == 2 + scores = {r["score"] for r in result["row_boxes"]} + assert scores == {0.9, 0.8} + + +def test_dedup_col_axis(): + # Two overlapping col boxes (x-axis) + cols = [_b(0, 10, score=0.4, axis=0), _b(6, 16, score=0.7, axis=0)] + result = dedup_row_col_bands({"row_boxes": [], "col_boxes": cols}) + kept = result["col_boxes"] + assert len(kept) == 1 + assert kept[0]["score"] == 0.7 + + +def test_dedup_no_overlap_at_exact_threshold(): + # overlap=3, smaller=10 -> ratio=0.3 (NOT > 0.3) -> keep both + rows = [_b(0, 10, score=0.9), _b(7, 17, score=0.8)] + result = dedup_row_col_bands({"row_boxes": rows, "col_boxes": []}) + assert len(result["row_boxes"]) == 2 + + +def test_dedup_other_keys_passthrough(): + pred = { + "row_boxes": [_b(0, 10)], + "col_boxes": [_b(0, 10, axis=0)], + "spanning_cells": [{"bbox": [0, 0, 10, 10]}], + "column_headers": [{"bbox": [0, 0, 30, 5]}], + } + result = dedup_row_col_bands(pred) + assert result["spanning_cells"] is pred["spanning_cells"] + assert result["column_headers"] is pred["column_headers"] + + +def test_dedup_makes_previously_invalid_grid_valid(): + # Two rows heavily overlapping -> validate fails before dedup, passes after + rows = [_b(0, 20, score=0.9), _b(5, 25, score=0.5)] + cols = [_b(0, 10, axis=0), _b(10, 20, axis=0), _b(20, 30, axis=0)] + pred = {"row_boxes": rows, "col_boxes": cols} + + canonical_before = normalize_tatr_prediction(pred) + valid_before = validate_grid_geometry( + sorted(pred["row_boxes"], key=lambda r: r["bbox"][1]), + sorted(pred["col_boxes"], key=lambda c: c["bbox"][0]), + canonical_before["cells"], + ) + assert not valid_before + + deduped = dedup_row_col_bands(pred) + canonical_after = normalize_tatr_prediction(deduped) + valid_after = validate_grid_geometry( + sorted(deduped["row_boxes"], key=lambda r: r["bbox"][1]), + sorted(deduped["col_boxes"], key=lambda c: c["bbox"][0]), + canonical_after["cells"], + ) + assert valid_after