diff --git a/CHANGELOG.md b/CHANGELOG.md index 36aeaeb..0dc6ef8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,16 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Added + +- `prune_chart_to_tree` (default `False`): when on, chart rows whose + `chart_strain_field` value isn't a tree tip are filtered out before + drawing. CLI form: + `--prune-chart-to-tree / --no-prune-chart-to-tree`. + ([#6](https://github.com/jbloomlab/tree-annotated-plot/issues/6)) + ## [0.2.2] - 2026-05-09 ### Fixed diff --git a/src/tree_annotated_plot/_config.py b/src/tree_annotated_plot/_config.py index 75c1c98..659a62f 100644 --- a/src/tree_annotated_plot/_config.py +++ b/src/tree_annotated_plot/_config.py @@ -109,9 +109,14 @@ class PlotConfig: "When off (default), tree tips not present in the chart's strain " "set are a fatal error. When on, those tips (and any internal " "nodes whose subtrees become empty) are dropped before drawing, " - "with single-child internals collapsed into their kept child. " - "Chart strains not present in the tree are *always* fatal " - "regardless of this flag — pruning would silently lose plot data.", + "with single-child internals collapsed into their kept child.", + ] = False + + prune_chart_to_tree: Annotated[ + bool, + "When off (default), chart strains not present in the tree are a " + "fatal error. When on, chart rows whose `chart_strain_field` " + "value isn't a tree tip are filtered out before drawing.", ] = False strict_version: Annotated[ diff --git a/src/tree_annotated_plot/_plot.py b/src/tree_annotated_plot/_plot.py index 9ad3fb6..01742c5 100644 --- a/src/tree_annotated_plot/_plot.py +++ b/src/tree_annotated_plot/_plot.py @@ -55,6 +55,7 @@ def plot( scale_bar: bool = False, branch_length_units: str | None = None, prune_tree_to_chart: bool = False, + prune_chart_to_tree: bool = False, strict_version: bool = True, connect_leader_to_label: bool = False, strain_label_font_size: float = 10.0, @@ -82,6 +83,7 @@ def plot( scale_bar=scale_bar, branch_length_units=branch_length_units, prune_tree_to_chart=prune_tree_to_chart, + prune_chart_to_tree=prune_chart_to_tree, strict_version=strict_version, connect_leader_to_label=connect_leader_to_label, strain_label_font_size=strain_label_font_size, @@ -168,12 +170,33 @@ def _build( chart_strains = _extract_chart_strains(spec, axis_hits, config.chart_strain_field) + if config.prune_chart_to_tree and (set(chart_strains) - set(tip_names)): + _prune_chart_spec_to_strains( + spec, + chart_strain_field=config.chart_strain_field, + keep_strains=set(tip_names), + ) + chart = alt.Chart.from_dict(spec) + chart_strains = _extract_chart_strains( + spec, axis_hits, config.chart_strain_field + ) + if not chart_strains: + raise ValueError( + "prune_chart_to_tree=True dropped every chart row: no " + "chart strain matched any tree tip under " + f"chart_strain_field={config.chart_strain_field!r} / " + f"tree_strain_field={config.tree_strain_field!r}. Pruning " + "is meant for charts that bundle a superset of strains; " + "an empty intersection suggests a wrong field choice." + ) + _reconcile_tips_and_strains( tree_strains=tip_names, chart_strains=chart_strains, chart_strain_field=config.chart_strain_field, tree_strain_field=config.tree_strain_field, prune_tree_to_chart=config.prune_tree_to_chart, + prune_chart_to_tree=config.prune_chart_to_tree, chart_spec=spec, tree_source=tree, ) @@ -641,6 +664,70 @@ def _extract_chart_strains( return _extract_field_values_from_spec_data(spec, chart_strain_field) +def _prune_chart_spec_to_strains( + spec: dict, *, chart_strain_field: str, keep_strains: set[str] +) -> None: + """Filter a Vega-Lite spec in place to drop rows outside `keep_strains`. + + Mutates three kinds of structure: + - top-level `datasets` entries (each a list of row-dicts). + - inline `data.values` lists anywhere in the spec tree. + - explicit `sort` lists on encoding channels bound to + `chart_strain_field` (any other `sort` is left alone). + + Rows that don't carry `chart_strain_field` at all are preserved + (we have no signal to drop them). URL-backed data raises — we + can't fetch + filter at plot time, mirroring `_extract_chart_strains`. + """ + + def row_kept(row: Any) -> bool: + if not isinstance(row, dict): + return True + if chart_strain_field not in row: + return True + return row[chart_strain_field] in keep_strains + + datasets = spec.get("datasets") if isinstance(spec, dict) else None + if isinstance(datasets, dict): + for name, rows in list(datasets.items()): + if isinstance(rows, list): + datasets[name] = [row for row in rows if row_kept(row)] + + def walk(node: Any) -> None: + if isinstance(node, dict): + data = node.get("data") + if isinstance(data, dict): + if "url" in data: + raise ValueError( + f"chart references data via URL ({data['url']!r}); " + "URL data is not supported, so prune_chart_to_tree " + "cannot filter it. Materialize the data inline " + "(via alt.Chart(df) with a pandas DataFrame) before " + "saving the chart." + ) + if "values" in data and isinstance(data["values"], list): + data["values"] = [row for row in data["values"] if row_kept(row)] + encoding = node.get("encoding") + if isinstance(encoding, dict): + for channel in encoding.values(): + if not isinstance(channel, dict): + continue + if channel.get("field") != chart_strain_field: + continue + sort = channel.get("sort") + if isinstance(sort, list): + channel["sort"] = [s for s in sort if s in keep_strains] + for k, v in node.items(): + if k in ("data", "datasets"): + continue + walk(v) + elif isinstance(node, list): + for item in node: + walk(item) + + walk(spec) + + def _extract_field_values_from_spec_data(spec: dict, field: str) -> list[str]: """Walk spec for inline / named data and return distinct values of `field`. @@ -708,13 +795,16 @@ def _reconcile_tips_and_strains( chart_strain_field: str, tree_strain_field: str, prune_tree_to_chart: bool, + prune_chart_to_tree: bool, chart_spec: dict, tree_source: Any, ) -> None: """Verify tree strains and chart strains are reconcilable. Three asymmetries: - - chart strains not in tree → always fatal. + - chart strains not in tree → fatal unless `prune_chart_to_tree=True` + (in which case the chart spec has already been pre-filtered upstream + and this set is expected to be empty by the time we get here). - tree tips not in chart → fatal unless `prune_tree_to_chart=True`. - (duplicate tree_strain_field values across tips → handled by the separate `_check_no_duplicate_tip_strains`.) @@ -728,7 +818,9 @@ def _reconcile_tips_and_strains( chart_minus_tree = chart_set - tree_set tree_minus_chart = tree_set - chart_set - if not chart_minus_tree and (not tree_minus_chart or prune_tree_to_chart): + chart_ok = not chart_minus_tree + tree_ok = not tree_minus_chart or prune_tree_to_chart + if chart_ok and tree_ok: return hints = _candidate_field_hints( @@ -748,6 +840,7 @@ def _reconcile_tips_and_strains( chart_minus_tree=chart_minus_tree, tree_minus_chart=tree_minus_chart, prune_tree_to_chart=prune_tree_to_chart, + prune_chart_to_tree=prune_chart_to_tree, hints=hints, ) ) @@ -762,14 +855,16 @@ def _format_strain_mismatch( chart_minus_tree: set[str], tree_minus_chart: set[str], prune_tree_to_chart: bool, + prune_chart_to_tree: bool, hints: list[str], ) -> str: parts: list[str] = [] - if chart_minus_tree: + if chart_minus_tree and not prune_chart_to_tree: parts.append( f"{len(chart_minus_tree)} chart strain(s) are not present in the " - "tree (these would be silently dropped if we pruned, so this is " - "always fatal)." + "tree. Pass `prune_chart_to_tree=True` to drop the offending " + "chart rows automatically (use with care — this discards plot " + "data)." ) if tree_minus_chart and not prune_tree_to_chart: parts.append( @@ -783,7 +878,7 @@ def _format_strain_mismatch( ) parts.append("Sample chart_strain_field values: " f"{sorted(chart_strains)[:5]}") parts.append("Sample tree_strain_field values: " f"{sorted(tree_strains)[:5]}") - if chart_minus_tree: + if chart_minus_tree and not prune_chart_to_tree: parts.append(f"Sample chart-only values: {sorted(chart_minus_tree)[:5]}") if tree_minus_chart and not prune_tree_to_chart: parts.append(f"Sample tree-only values: {sorted(tree_minus_chart)[:5]}") diff --git a/tests/test_reconciliation.py b/tests/test_reconciliation.py index 772446f..abf97e2 100644 --- a/tests/test_reconciliation.py +++ b/tests/test_reconciliation.py @@ -73,7 +73,7 @@ def _chart_for_strains(strains: list[str], *, height: int = 200) -> alt.Chart: ) -# ---------- chart-not-in-tree (always fatal) ---------- +# ---------- chart-not-in-tree (fatal unless prune_chart_to_tree) ---------- def test_chart_strain_not_in_tree_is_fatal_default() -> None: @@ -90,9 +90,9 @@ def test_chart_strain_not_in_tree_is_fatal_default() -> None: ) -def test_chart_strain_not_in_tree_is_fatal_even_with_prune() -> None: +def test_chart_strain_not_in_tree_is_fatal_with_only_tree_prune() -> None: """`prune_tree_to_chart=True` only drops *tree* tips. A chart strain - not in the tree still raises — pruning would silently lose plot data.""" + not in the tree still raises — the two flags are orthogonal.""" chart = _chart_for_strains(["A1", "A2", "X"]) with pytest.raises(ValueError, match="not present in the tree"): tree_annotated_plot.plot( @@ -105,6 +105,97 @@ def test_chart_strain_not_in_tree_is_fatal_even_with_prune() -> None: ) +def test_chart_strain_not_in_tree_succeeds_with_prune_chart_to_tree() -> None: + """`prune_chart_to_tree=True` filters chart rows whose strain isn't a + tree tip. The resulting chart's strain-axis sort matches the kept tree + tip order, and the dropped strain is gone from the chart's data.""" + chart = _chart_for_strains(["A1", "A2", "A3", "B1", "B2", "X"]) + out = tree_annotated_plot.plot( + _auspice_two_clades(), + chart, + chart_strain_field="strain", + tree_strain_field="name", + branch_length="div", + prune_chart_to_tree=True, + ) + assert isinstance(out, alt.HConcatChart) + spec = out.to_dict() + # The strain-axis sort on the user-chart panel should be exactly the + # tree's tip order — no `X`. + sorts = [] + for ch in out.hconcat: + ch_spec = ch.to_dict() + enc = ch_spec.get("encoding", {}) + for channel in enc.values(): + if isinstance(channel, dict) and channel.get("field") == "strain": + if isinstance(channel.get("sort"), list): + sorts.append(channel["sort"]) + assert sorts, "expected at least one strain-axis encoding with a sort" + for s in sorts: + assert "X" not in s + assert set(s) <= {"A1", "A2", "A3", "B1", "B2"} + # And the user-chart data (now in `datasets`) should not contain X rows. + for rows in (spec.get("datasets") or {}).values(): + if isinstance(rows, list) and rows and isinstance(rows[0], dict): + if "strain" in rows[0]: + assert all(r["strain"] != "X" for r in rows) + + +def test_prune_chart_to_tree_zero_overlap_raises() -> None: + """If pruning would drop every chart row (no overlap with tree tips), + raise a clear error rather than producing an empty plot.""" + chart = _chart_for_strains(["X", "Y", "Z"]) + with pytest.raises(ValueError, match="dropped every chart row"): + tree_annotated_plot.plot( + _auspice_two_clades(), + chart, + chart_strain_field="strain", + tree_strain_field="name", + branch_length="div", + prune_chart_to_tree=True, + ) + + +def test_prune_chart_and_tree_combined() -> None: + """Both flags may be on at once: the chart has extras (X) and the tree + has tips the chart lacks (B1, B2). Pruning is bidirectional.""" + chart = _chart_for_strains(["A1", "A2", "A3", "X"]) + out = tree_annotated_plot.plot( + _auspice_two_clades(), + chart, + chart_strain_field="strain", + tree_strain_field="name", + branch_length="div", + prune_tree_to_chart=True, + prune_chart_to_tree=True, + ) + assert isinstance(out, alt.HConcatChart) + # Final intersection is {A1, A2, A3}: chart loses X, tree loses B1/B2. + found_sort = None + for ch in out.hconcat: + enc = ch.to_dict().get("encoding", {}) + for channel in enc.values(): + if isinstance(channel, dict) and channel.get("field") == "strain": + if isinstance(channel.get("sort"), list): + found_sort = channel["sort"] + assert found_sort is not None + assert set(found_sort) == {"A1", "A2", "A3"} + + +def test_prune_chart_to_tree_default_error_mentions_flag() -> None: + """When chart has strains not in tree and the user hasn't opted in, + the error message should suggest `prune_chart_to_tree=True`.""" + chart = _chart_for_strains(["A1", "A2", "A3", "B1", "B2", "X"]) + with pytest.raises(ValueError, match="prune_chart_to_tree=True"): + tree_annotated_plot.plot( + _auspice_two_clades(), + chart, + chart_strain_field="strain", + tree_strain_field="name", + branch_length="div", + ) + + # ---------- tree-not-in-chart (fatal unless prune) ----------