From 1df2dc5b2fafa3579de15bf94119d94de04fa02f Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 23 Feb 2026 12:29:13 +0000 Subject: [PATCH 1/2] Initial plan From 34d708b36e3c20e2b35fff8cc2a6e44df764658a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 23 Feb 2026 12:53:30 +0000 Subject: [PATCH 2/2] feat: Add members parameter to @dy.filter() for member-specific collection filters Co-authored-by: borchero <22455425+borchero@users.noreply.github.com> --- dataframely/_filter.py | 30 +- dataframely/collection/_base.py | 41 ++- dataframely/collection/collection.py | 182 +++++++++--- tests/collection/test_filter_members.py | 370 ++++++++++++++++++++++++ 4 files changed, 567 insertions(+), 56 deletions(-) create mode 100644 tests/collection/test_filter_members.py diff --git a/dataframely/_filter.py b/dataframely/_filter.py index 3e39ad12..c3859398 100644 --- a/dataframely/_filter.py +++ b/dataframely/_filter.py @@ -1,7 +1,7 @@ # Copyright (c) QuantCo 2025-2026 # SPDX-License-Identifier: BSD-3-Clause -from collections.abc import Callable +from collections.abc import Callable, Sequence from typing import Generic, TypeVar import polars as pl @@ -12,11 +12,18 @@ class Filter(Generic[C]): """Internal class representing logic for filtering members of a collection.""" - def __init__(self, logic: Callable[[C], pl.LazyFrame]) -> None: + def __init__( + self, + logic: Callable[[C], pl.LazyFrame], + members: Sequence[str] | None = None, + ) -> None: self.logic = logic + self.members = members -def filter() -> Callable[[Callable[[C], pl.LazyFrame]], Filter[C]]: +def filter( + members: Sequence[str] | None = None, +) -> Callable[[Callable[[C], pl.LazyFrame]], Filter[C]]: """Mark a function as filters for rows in the members of a collection. The name of the function will be used as the name of the filter. The name must not @@ -26,10 +33,19 @@ def filter() -> Callable[[Callable[[C], pl.LazyFrame]], Filter[C]]: A filter receives a collection as input and must return a data frame like the following: - - The columns must be a superset of the common primary keys across all members. - - The rows must provide the primary keys which ought to be *kept* across the + - The columns must be a superset of the primary keys of the applicable members + (the common primary key across all members if ``members`` is not specified, or + the common primary key across the specified members otherwise). + - The rows must provide the primary keys which ought to be *kept* in the applicable members. The filter results in the removal of rows which are lost as the result - of inner-joining members onto the return value of this function. + of inner-joining applicable members onto the return value of this function. + + Args: + members: The names of the collection members to which this filter applies. + If ``None`` (the default), the filter applies to all non-ignored members, + using the collection's common primary key. + If specified, the filter only applies to the listed members and the join + key used is the common primary key of those members. Attention: Make sure to provide unique combinations of the primary keys or the filters @@ -43,6 +59,6 @@ def filter() -> Callable[[Callable[[C], pl.LazyFrame]], Filter[C]]: """ def decorator(validation_fn: Callable[[C], pl.LazyFrame]) -> Filter[C]: - return Filter(logic=validation_fn) + return Filter(logic=validation_fn, members=members) return decorator diff --git a/dataframely/collection/_base.py b/dataframely/collection/_base.py index ff4140c6..32a889c6 100644 --- a/dataframely/collection/_base.py +++ b/dataframely/collection/_base.py @@ -132,11 +132,42 @@ def __new__( # 1) Check that there are overlapping primary keys that allow the application # of filters. if len(non_ignored_member_schemas) > 0 and len(result.filters) > 0: - if len(_common_primary_key(non_ignored_member_schemas)) == 0: - raise ImplementationError( - "Members of a collection must have an overlapping primary key " - "but did not find any." - ) + # Filters without 'members' apply to all non-ignored members and require + # their common primary key to be non-empty. + global_filters = [f for f in result.filters.values() if f.members is None] + if len(global_filters) > 0: + if len(_common_primary_key(non_ignored_member_schemas)) == 0: + raise ImplementationError( + "Members of a collection must have an overlapping primary key " + "but did not find any." + ) + + # Filters with 'members' apply only to specific members. + for filter_name, filter_obj in result.filters.items(): + if filter_obj.members is None: + continue + if len(filter_obj.members) == 0: + raise ImplementationError( + f"Filter '{filter_name}' must specify at least one member." + ) + for m in filter_obj.members: + if m not in result.members: + raise ImplementationError( + f"Filter '{filter_name}' references unknown member '{m}'." + ) + if result.members[m].ignored_in_filters: + raise ImplementationError( + f"Filter '{filter_name}' references member '{m}' which is " + "ignored in filters." + ) + specified_schemas = [ + result.members[m].schema for m in filter_obj.members + ] + if len(_common_primary_key(specified_schemas)) == 0: + raise ImplementationError( + f"Members specified in filter '{filter_name}' must have an " + "overlapping primary key but did not find any." + ) # 2) Check that filter names do not overlap with any column or rule names if len(result.members) > 0: diff --git a/dataframely/collection/collection.py b/dataframely/collection/collection.py index 665b85ae..86d0327b 100644 --- a/dataframely/collection/collection.py +++ b/dataframely/collection/collection.py @@ -42,7 +42,7 @@ from dataframely.random import Generator from dataframely.schema import _schema_from_dict -from ._base import BaseCollection, CollectionMember +from ._base import BaseCollection, CollectionMember, _common_primary_key from .filter_result import CollectionFilterResult if sys.version_info >= (3, 11): @@ -316,6 +316,13 @@ def _filters_match() -> bool: empty_right = other.create_empty() for name in filters_lhs: + # Members specification must match + _lhs_m = filters_lhs[name].members + _rhs_m = filters_rhs[name].members + lhs_members = set(_lhs_m) if _lhs_m is not None else None + rhs_members = set(_rhs_m) if _rhs_m is not None else None + if lhs_members != rhs_members: + return False lhs = filters_lhs[name].logic(empty_left) rhs = filters_rhs[name].logic(empty_right) if lhs.serialize() != rhs.serialize(): @@ -436,28 +443,51 @@ def validate( if filters := cls._filters(): result_cls = cls._init(members) - primary_key = cls.common_primary_key() - filter_names = list(filters.keys()) - keep = [ - filter.logic(result_cls).select( - *primary_key, pl.lit(True).alias(name) - ) - for name, filter in filters.items() - ] - members = { - name: ( - _join_all( - lf, *keep, on=primary_key, how="left", maintain_order="left" - ) - .filter( - all_rules_required( - filter_names, null_is_valid=False, schema_name=name + common_pk = cls.common_primary_key() + + # Build per-filter (lazy_frame, join_key) pairs + filter_keep: dict[str, tuple[pl.LazyFrame, list[str]]] = {} + for filter_name, f in filters.items(): + if f.members is None: + join_key = common_pk + else: + join_key = sorted( + _common_primary_key( + cls.members()[m].schema for m in f.members ) ) - .drop(filter_names) + filter_keep[filter_name] = ( + f.logic(result_cls).select( + *join_key, pl.lit(True).alias(filter_name) + ), + join_key, ) - for name, lf in members.items() - } + + new_members: dict[str, pl.LazyFrame] = {} + for member_name, lf in members.items(): + # Find filters applicable to this member + applicable = [ + (fname, flt_lf, join_key) + for fname, (flt_lf, join_key) in filter_keep.items() + if _filter_applies_to_member(filters[fname].members, member_name) + ] + if not applicable: + new_members[member_name] = lf + continue + result_lf = lf + for fname, flt_lf, join_key in applicable: + result_lf = result_lf.join( + flt_lf, on=join_key, how="left", maintain_order="left" + ) + applicable_names = [fname for fname, _, _ in applicable] + new_members[member_name] = result_lf.filter( + all_rules_required( + applicable_names, + null_is_valid=False, + schema_name=member_name, + ) + ).drop(applicable_names) + members = new_members return cls._init(members) @@ -489,18 +519,42 @@ def is_valid(cls, data: Mapping[str, FrameType], /, *, cast: bool = False) -> bo return False members[member] = data[member].lazy() - # Make sure that inner-joining all filters does not remove any rows - if filters := cls._filters().values(): + # Make sure that applying all filters does not remove any rows from applicable members + if filters := cls._filters(): result_cls = cls._init(members) - primary_key = cls.common_primary_key() - keep = [filter.logic(result_cls).select(primary_key) for filter in filters] - joined = _join_all(*keep, on=primary_key, how="inner") - removed_rows = pl.collect_all( - data[member].lazy().join(joined, on=primary_key, how="anti") - for member in cls.members() - if member in data - ) - return all(df.is_empty() for df in removed_rows) + common_pk = cls.common_primary_key() + + # Compute filter results with their per-filter join keys + filter_results: dict[str, tuple[pl.LazyFrame, list[str]]] = {} + for name, f in filters.items(): + if f.members is None: + join_key = common_pk + else: + join_key = sorted( + _common_primary_key(cls.members()[m].schema for m in f.members) + ) + filter_results[name] = (f.logic(result_cls).select(join_key), join_key) + + # For each non-ignored member, check each applicable filter + check_frames = [] + for member_name in cls.members(): + if member_name not in data: + continue + member_info = cls.members()[member_name] + if member_info.ignored_in_filters: + continue + for filter_name, (filter_keep, join_key) in filter_results.items(): + f = filters[filter_name] + if f.members is not None and member_name not in f.members: + continue + check_frames.append( + data[member_name] + .lazy() + .join(filter_keep, on=join_key, how="anti") + ) + if check_frames: + removed_rows = pl.collect_all(check_frames) + return all(df.is_empty() for df in removed_rows) return True @@ -578,13 +632,18 @@ class HospitalInvoiceData(dy.Collection): result_cls = cls._init(results) primary_key = cls.common_primary_key() - keep: dict[str, pl.LazyFrame] = {} - for name, filter in filters.items(): + # keep maps filter_name -> (lazy_frame_of_kept_keys, join_key) + keep: dict[str, tuple[pl.LazyFrame, list[str]]] = {} + for name, f in filters.items(): + if f.members is None: + join_key = primary_key + else: + join_key = sorted( + _common_primary_key(cls.members()[m].schema for m in f.members) + ) keep[name] = ( - filter.logic(result_cls) - .select(primary_key) - .pipe(collect_if, eager) - .lazy() + f.logic(result_cls).select(join_key).pipe(collect_if, eager).lazy(), + join_key, ) drop: dict[str, pl.LazyFrame] = {} @@ -605,11 +664,22 @@ class HospitalInvoiceData(dy.Collection): if member_info.ignored_in_filters: continue + # Only apply filters that are applicable to this member + applicable_keep = { + name: (filter_keep, filter_key) + for name, (filter_keep, filter_key) in keep.items() + if _filter_applies_to_member(filters[name].members, member_name) + } + + all_filter_columns = list(applicable_keep.keys()) + list(drop.keys()) + if not all_filter_columns: + continue + lf_with_eval = filtered.lazy() - for name, filter_keep in keep.items(): + for name, (filter_keep, filter_key) in applicable_keep.items(): lf_with_eval = lf_with_eval.join( filter_keep.lazy().with_columns(pl.lit(True).alias(name)), - on=primary_key, + on=filter_key, how="left", maintain_order="left", ).with_columns(pl.col(name).fill_null(False)) @@ -626,7 +696,6 @@ class HospitalInvoiceData(dy.Collection): # Filtering `lf_with_eval` by the rows for which all joins # "succeeded", we can identify the rows that pass all the filters. We # keep these rows for the result. - all_filter_columns = list(keep.keys()) + list(drop.keys()) results[member_name] = lf_with_eval.filter( pl.all_horizontal(all_filter_columns) ).drop(all_filter_columns) @@ -848,8 +917,11 @@ def serialize(cls) -> str: for name, info in cls.members().items() }, "filters": { - name: filter.logic(cls.create_empty()) - for name, filter in cls._filters().items() + name: { + "logic": f.logic(cls.create_empty()), + "members": sorted(f.members) if f.members is not None else None, + } + for name, f in cls._filters().items() }, } return json.dumps(result, cls=SchemaJSONEncoder) @@ -1347,8 +1419,17 @@ def deserialize_collection(data: str, strict: bool = True) -> type[Collection] | { "__annotations__": annotations, **{ - name: Filter(logic=lambda _, logic=logic: logic) # type: ignore - for name, logic in decoded["filters"].items() + name: ( + # New format: filter_data is a dict with "logic" and "members" + Filter( + logic=lambda _, logic=filter_data["logic"]: logic, # type: ignore + members=filter_data.get("members"), + ) + if isinstance(filter_data, dict) + # Old format: filter_data is a LazyFrame directly + else Filter(logic=lambda _, logic=filter_data: logic) # type: ignore + ) + for name, filter_data in decoded["filters"].items() }, }, ) @@ -1363,6 +1444,19 @@ def deserialize_collection(data: str, strict: bool = True) -> type[Collection] | # --------------------------------------- UTILS -------------------------------------- # +def _filter_applies_to_member( + filter_members: Sequence[str] | None, member_name: str +) -> bool: + """Return True if the filter applies to the given member. + + If ``filter_members`` is None, the filter applies to all non-ignored members. + Otherwise, the filter only applies to the members listed in ``filter_members``. + """ + if filter_members is None: + return True + return member_name in filter_members + + def _join_all( *dfs: pl.LazyFrame, on: list[str], diff --git a/tests/collection/test_filter_members.py b/tests/collection/test_filter_members.py new file mode 100644 index 00000000..6ae47277 --- /dev/null +++ b/tests/collection/test_filter_members.py @@ -0,0 +1,370 @@ +# Copyright (c) QuantCo 2025-2026 +# SPDX-License-Identifier: BSD-3-Clause + +import polars as pl +import pytest +from polars.testing import assert_frame_equal + +import dataframely as dy +from dataframely._filter import Filter +from dataframely.exc import ImplementationError +from dataframely.testing import create_collection, create_schema + +# ------------------------------------------------------------------------------------ # +# SCHEMAS # +# ------------------------------------------------------------------------------------ # + + +class InvoiceSchema(dy.Schema): + invoice_id = dy.Integer(primary_key=True) + amount = dy.Integer(nullable=False) + + +class DiagnosisSchema(dy.Schema): + invoice_id = dy.Integer(primary_key=True) + diagnosis_code = dy.String(primary_key=True) + diagnosis_date = dy.Integer(nullable=False) + + +class InvoiceOnlyFilter(dy.Collection): + """Filter applied only to invoices member.""" + + invoices: dy.LazyFrame[InvoiceSchema] + diagnoses: dy.LazyFrame[DiagnosisSchema] + + @dy.filter(members=["invoices"]) + def filter_invoices(self) -> pl.LazyFrame: + # Keep only invoices with positive amount + return self.invoices.filter(pl.col("amount") > 0) + + +class DiagnosisOnlyFilter(dy.Collection): + """Filter applied only to diagnoses member, using data from both members.""" + + invoices: dy.LazyFrame[InvoiceSchema] + diagnoses: dy.LazyFrame[DiagnosisSchema] + + @dy.filter(members=["diagnoses"]) + def filter_diagnoses(self) -> pl.LazyFrame: + # Keep diagnoses with invoice_id that exists in invoices + return self.diagnoses.join( + self.invoices.select("invoice_id"), + on="invoice_id", + how="semi", + ) + + +class BothMembersFilter(dy.Collection): + """Filter with explicit members list covering all members.""" + + invoices: dy.LazyFrame[InvoiceSchema] + diagnoses: dy.LazyFrame[DiagnosisSchema] + + @dy.filter(members=["invoices", "diagnoses"]) + def filter_both(self) -> pl.LazyFrame: + # Keep invoice_ids that appear in both + return self.invoices.join( + self.diagnoses.select("invoice_id").unique(), + on="invoice_id", + ).select("invoice_id") + + +# ------------------------------------------------------------------------------------ # +# FILTER TESTS # +# ------------------------------------------------------------------------------------ # + + +@pytest.mark.parametrize("eager", [True, False]) +def test_member_filter_only_affects_specified_member(eager: bool) -> None: + invoices = pl.LazyFrame({"invoice_id": [1, 2, 3], "amount": [10, -5, 20]}) + diagnoses = pl.LazyFrame( + { + "invoice_id": [1, 2, 3], + "diagnosis_code": ["A01", "B02", "C03"], + "diagnosis_date": [100, 200, 300], + } + ) + + result, failure = InvoiceOnlyFilter.filter( + {"invoices": invoices, "diagnoses": diagnoses}, eager=eager + ) + + # Only invoice 2 (amount=-5) should be removed from invoices + assert_frame_equal( + result.invoices.collect().sort("invoice_id"), + pl.DataFrame({"invoice_id": [1, 3], "amount": [10, 20]}), + ) + # Diagnoses should NOT be filtered by this filter + assert result.diagnoses.collect().height == 3 + + # Failure info for invoices should record filter failures + assert failure["invoices"].counts() == {"filter_invoices": 1} + # Failure info for diagnoses should be empty (filter doesn't apply) + assert len(failure["diagnoses"]) == 0 + + +@pytest.mark.parametrize("eager", [True, False]) +def test_member_filter_uses_member_primary_key(eager: bool) -> None: + """Filter on diagnoses uses the full primary key of diagnoses.""" + invoices = pl.LazyFrame({"invoice_id": [1, 2], "amount": [10, 20]}) + # invoice_id=1 has two diagnoses; only one should pass + diagnoses = pl.LazyFrame( + { + "invoice_id": [1, 1, 2], + "diagnosis_code": ["A01", "B02", "C03"], + "diagnosis_date": [100, 200, 300], + } + ) + + result, failure = DiagnosisOnlyFilter.filter( + {"invoices": invoices, "diagnoses": diagnoses}, eager=eager + ) + + # All diagnoses have matching invoice_ids, so none should be filtered + assert result.diagnoses.collect().height == 3 + # Invoices not affected + assert result.invoices.collect().height == 2 + + +@pytest.mark.parametrize("eager", [True, False]) +def test_member_filter_removes_unmatched_diagnoses(eager: bool) -> None: + """Diagnoses with invoice_id not in invoices are removed.""" + invoices = pl.LazyFrame({"invoice_id": [1], "amount": [10]}) + diagnoses = pl.LazyFrame( + { + "invoice_id": [1, 2], + "diagnosis_code": ["A01", "B02"], + "diagnosis_date": [100, 200], + } + ) + + result, failure = DiagnosisOnlyFilter.filter( + {"invoices": invoices, "diagnoses": diagnoses}, eager=eager + ) + + # Only diagnosis for invoice_id=1 should remain + assert result.diagnoses.collect().height == 1 + assert result.diagnoses.collect()["invoice_id"].to_list() == [1] + # Invoices not affected + assert result.invoices.collect().height == 1 + + assert failure["diagnoses"].counts() == {"filter_diagnoses": 1} + assert len(failure["invoices"]) == 0 + + +@pytest.mark.parametrize("eager", [True, False]) +def test_member_filter_on_multiple_members(eager: bool) -> None: + """Filter that applies to multiple members uses common primary key.""" + invoices = pl.LazyFrame({"invoice_id": [1, 2, 3], "amount": [10, 20, 30]}) + diagnoses = pl.LazyFrame( + { + "invoice_id": [1, 2, 4], + "diagnosis_code": ["A01", "B02", "C03"], + "diagnosis_date": [100, 200, 300], + } + ) + + result, failure = BothMembersFilter.filter( + {"invoices": invoices, "diagnoses": diagnoses}, eager=eager + ) + + # Only invoice_ids present in both are kept + assert sorted(result.invoices.collect()["invoice_id"].to_list()) == [1, 2] + assert sorted(result.diagnoses.collect()["invoice_id"].to_list()) == [1, 2] + + assert failure["invoices"].counts() == {"filter_both": 1} + assert failure["diagnoses"].counts() == {"filter_both": 1} + + +# ------------------------------------------------------------------------------------ # +# VALIDATE TESTS # +# ------------------------------------------------------------------------------------ # + + +@pytest.mark.parametrize("eager", [True, False]) +def test_is_valid_member_filter(eager: bool) -> None: + invoices_valid = pl.LazyFrame({"invoice_id": [1, 2], "amount": [10, 20]}) + invoices_invalid = pl.LazyFrame({"invoice_id": [1, 2], "amount": [10, -5]}) + diagnoses = pl.LazyFrame( + { + "invoice_id": [1, 2], + "diagnosis_code": ["A01", "B02"], + "diagnosis_date": [100, 200], + } + ) + + assert InvoiceOnlyFilter.is_valid( + {"invoices": invoices_valid, "diagnoses": diagnoses} + ) + assert not InvoiceOnlyFilter.is_valid( + {"invoices": invoices_invalid, "diagnoses": diagnoses} + ) + + +@pytest.mark.parametrize("eager", [True, False]) +def test_validate_member_filter_lazy(eager: bool) -> None: + invoices = pl.LazyFrame({"invoice_id": [1, 2], "amount": [10, 20]}) + diagnoses = pl.LazyFrame( + { + "invoice_id": [1, 2], + "diagnosis_code": ["A01", "B02"], + "diagnosis_date": [100, 200], + } + ) + + validated = InvoiceOnlyFilter.validate( + {"invoices": invoices, "diagnoses": diagnoses}, eager=eager + ) + assert validated.invoices.collect().height == 2 + assert validated.diagnoses.collect().height == 2 + + +# ------------------------------------------------------------------------------------ # +# SERIALIZATION TESTS # +# ------------------------------------------------------------------------------------ # + + +def test_serialize_includes_members() -> None: + import json + + serialized = json.loads(InvoiceOnlyFilter.serialize()) + filter_data = serialized["filters"]["filter_invoices"] + assert isinstance(filter_data, dict) + assert filter_data["members"] == ["invoices"] + + +def test_serialize_null_members_for_global_filter() -> None: + import json + + collection = create_collection( + "test", + { + "s1": create_schema("schema1", {"a": dy.Int64(primary_key=True)}), + "s2": create_schema("schema2", {"a": dy.Int64(primary_key=True)}), + }, + {"filter1": Filter(lambda c: c.s1.join(c.s2, on="a"))}, + ) + serialized = json.loads(collection.serialize()) + filter_data = serialized["filters"]["filter1"] + assert isinstance(filter_data, dict) + assert filter_data["members"] is None + + +def test_roundtrip_matches_with_members() -> None: + serialized = InvoiceOnlyFilter.serialize() + decoded = dy.deserialize_collection(serialized) + assert InvoiceOnlyFilter.matches(decoded) + + +def test_matches_differs_when_members_differ() -> None: + """Two collections with same filter logic but different members should not match.""" + + class CollA(dy.Collection): + invoices: dy.LazyFrame[InvoiceSchema] + diagnoses: dy.LazyFrame[DiagnosisSchema] + + @dy.filter(members=["invoices"]) + def my_filter(self) -> pl.LazyFrame: + return self.invoices.filter(pl.col("amount") > 0) + + class CollB(dy.Collection): + invoices: dy.LazyFrame[InvoiceSchema] + diagnoses: dy.LazyFrame[DiagnosisSchema] + + @dy.filter(members=["diagnoses"]) + def my_filter(self) -> pl.LazyFrame: + return self.invoices.filter(pl.col("amount") > 0) + + assert not CollA.matches(CollB) + + +def test_matches_with_same_members() -> None: + class CollA(dy.Collection): + invoices: dy.LazyFrame[InvoiceSchema] + diagnoses: dy.LazyFrame[DiagnosisSchema] + + @dy.filter(members=["invoices"]) + def my_filter(self) -> pl.LazyFrame: + return self.invoices.filter(pl.col("amount") > 0) + + class CollB(dy.Collection): + invoices: dy.LazyFrame[InvoiceSchema] + diagnoses: dy.LazyFrame[DiagnosisSchema] + + @dy.filter(members=["invoices"]) + def my_filter(self) -> pl.LazyFrame: + return self.invoices.filter(pl.col("amount") > 0) + + assert CollA.matches(CollB) + + +# ------------------------------------------------------------------------------------ # +# IMPLEMENTATION TESTS # +# ------------------------------------------------------------------------------------ # + + +def test_filter_members_unknown_member() -> None: + with pytest.raises( + ImplementationError, + match=r"Filter 'f' references unknown member 'nonexistent'", + ): + create_collection( + "test", + { + "s1": create_schema("s1", {"a": dy.Integer(primary_key=True)}), + }, + filters={"f": Filter(lambda c: c.s1, members=["nonexistent"])}, + ) + + +def test_filter_members_empty_list() -> None: + with pytest.raises( + ImplementationError, + match=r"Filter 'f' must specify at least one member", + ): + create_collection( + "test", + { + "s1": create_schema("s1", {"a": dy.Integer(primary_key=True)}), + }, + filters={"f": Filter(lambda c: c.s1, members=[])}, + ) + + +def test_filter_members_no_common_primary_key() -> None: + with pytest.raises( + ImplementationError, + match=r"Members specified in filter 'f' must have an overlapping primary key", + ): + create_collection( + "test", + { + "s1": create_schema("s1", {"a": dy.Integer(primary_key=True)}), + "s2": create_schema("s2", {"b": dy.Integer(primary_key=True)}), + }, + filters={"f": Filter(lambda c: c.s1, members=["s1", "s2"])}, + ) + + +def test_filter_members_ignored_member() -> None: + from typing import Annotated + + from dataframely.testing import create_collection_raw + + schema_a = create_schema("a", {"a": dy.Integer(primary_key=True)}) + schema_b = create_schema("b", {"a": dy.Integer(primary_key=True)}) + with pytest.raises( + ImplementationError, + match=r"Filter 'f' references member 'ignored' which is ignored in filters", + ): + create_collection_raw( + "test", + annotations={ + "a": dy.LazyFrame[schema_a], + "ignored": Annotated[ + dy.LazyFrame[schema_b], + dy.CollectionMember(ignored_in_filters=True), + ], + }, + filters={"f": Filter(lambda c: c.a, members=["ignored"])}, + )