Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions dataframely/_filter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) QuantCo 2025-2026
# SPDX-License-Identifier: BSD-3-Clause

from collections.abc import Callable
from collections.abc import Callable, Sequence
from typing import Generic, TypeVar

import polars as pl
Expand All @@ -12,11 +12,18 @@
class Filter(Generic[C]):
"""Internal class representing logic for filtering members of a collection."""

def __init__(self, logic: Callable[[C], pl.LazyFrame]) -> None:
def __init__(
self,
logic: Callable[[C], pl.LazyFrame],
members: Sequence[str] | None = None,
) -> None:
self.logic = logic
self.members = members


def filter() -> Callable[[Callable[[C], pl.LazyFrame]], Filter[C]]:
def filter(
members: Sequence[str] | None = None,
) -> Callable[[Callable[[C], pl.LazyFrame]], Filter[C]]:
"""Mark a function as filters for rows in the members of a collection.

The name of the function will be used as the name of the filter. The name must not
Expand All @@ -26,10 +33,19 @@ def filter() -> Callable[[Callable[[C], pl.LazyFrame]], Filter[C]]:
A filter receives a collection as input and must return a data frame like the
following:

- The columns must be a superset of the common primary keys across all members.
- The rows must provide the primary keys which ought to be *kept* across the
- The columns must be a superset of the primary keys of the applicable members
(the common primary key across all members if ``members`` is not specified, or
the common primary key across the specified members otherwise).
- The rows must provide the primary keys which ought to be *kept* in the applicable
members. The filter results in the removal of rows which are lost as the result
of inner-joining members onto the return value of this function.
of inner-joining applicable members onto the return value of this function.

Args:
members: The names of the collection members to which this filter applies.
If ``None`` (the default), the filter applies to all non-ignored members,
using the collection's common primary key.
If specified, the filter only applies to the listed members and the join
key used is the common primary key of those members.

Attention:
Make sure to provide unique combinations of the primary keys or the filters
Expand All @@ -43,6 +59,6 @@ def filter() -> Callable[[Callable[[C], pl.LazyFrame]], Filter[C]]:
"""

def decorator(validation_fn: Callable[[C], pl.LazyFrame]) -> Filter[C]:
return Filter(logic=validation_fn)
return Filter(logic=validation_fn, members=members)

return decorator
41 changes: 36 additions & 5 deletions dataframely/collection/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,42 @@ def __new__(
# 1) Check that there are overlapping primary keys that allow the application
# of filters.
if len(non_ignored_member_schemas) > 0 and len(result.filters) > 0:
if len(_common_primary_key(non_ignored_member_schemas)) == 0:
raise ImplementationError(
"Members of a collection must have an overlapping primary key "
"but did not find any."
)
# Filters without 'members' apply to all non-ignored members and require
# their common primary key to be non-empty.
global_filters = [f for f in result.filters.values() if f.members is None]
if len(global_filters) > 0:
if len(_common_primary_key(non_ignored_member_schemas)) == 0:
raise ImplementationError(
"Members of a collection must have an overlapping primary key "
"but did not find any."
)

# Filters with 'members' apply only to specific members.
for filter_name, filter_obj in result.filters.items():
if filter_obj.members is None:
continue
if len(filter_obj.members) == 0:
raise ImplementationError(
f"Filter '{filter_name}' must specify at least one member."
)
for m in filter_obj.members:
if m not in result.members:
raise ImplementationError(
f"Filter '{filter_name}' references unknown member '{m}'."
)
if result.members[m].ignored_in_filters:
raise ImplementationError(
f"Filter '{filter_name}' references member '{m}' which is "
"ignored in filters."
)
specified_schemas = [
result.members[m].schema for m in filter_obj.members
]
if len(_common_primary_key(specified_schemas)) == 0:
raise ImplementationError(
f"Members specified in filter '{filter_name}' must have an "
"overlapping primary key but did not find any."
)

# 2) Check that filter names do not overlap with any column or rule names
if len(result.members) > 0:
Expand Down
182 changes: 138 additions & 44 deletions dataframely/collection/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from dataframely.random import Generator
from dataframely.schema import _schema_from_dict

from ._base import BaseCollection, CollectionMember
from ._base import BaseCollection, CollectionMember, _common_primary_key
from .filter_result import CollectionFilterResult

if sys.version_info >= (3, 11):
Expand Down Expand Up @@ -316,6 +316,13 @@ def _filters_match() -> bool:
empty_right = other.create_empty()

for name in filters_lhs:
# Members specification must match
_lhs_m = filters_lhs[name].members
_rhs_m = filters_rhs[name].members
lhs_members = set(_lhs_m) if _lhs_m is not None else None
rhs_members = set(_rhs_m) if _rhs_m is not None else None
if lhs_members != rhs_members:
return False
lhs = filters_lhs[name].logic(empty_left)
rhs = filters_rhs[name].logic(empty_right)
if lhs.serialize() != rhs.serialize():
Expand Down Expand Up @@ -436,28 +443,51 @@ def validate(

if filters := cls._filters():
result_cls = cls._init(members)
primary_key = cls.common_primary_key()
filter_names = list(filters.keys())
keep = [
filter.logic(result_cls).select(
*primary_key, pl.lit(True).alias(name)
)
for name, filter in filters.items()
]
members = {
name: (
_join_all(
lf, *keep, on=primary_key, how="left", maintain_order="left"
)
.filter(
all_rules_required(
filter_names, null_is_valid=False, schema_name=name
common_pk = cls.common_primary_key()

# Build per-filter (lazy_frame, join_key) pairs
filter_keep: dict[str, tuple[pl.LazyFrame, list[str]]] = {}
for filter_name, f in filters.items():
if f.members is None:
join_key = common_pk
else:
join_key = sorted(
_common_primary_key(
cls.members()[m].schema for m in f.members
)
)
.drop(filter_names)
filter_keep[filter_name] = (
f.logic(result_cls).select(
*join_key, pl.lit(True).alias(filter_name)
),
join_key,
)
for name, lf in members.items()
}

new_members: dict[str, pl.LazyFrame] = {}
for member_name, lf in members.items():
# Find filters applicable to this member
applicable = [
(fname, flt_lf, join_key)
for fname, (flt_lf, join_key) in filter_keep.items()
if _filter_applies_to_member(filters[fname].members, member_name)
]
if not applicable:
new_members[member_name] = lf
continue
result_lf = lf
for fname, flt_lf, join_key in applicable:
result_lf = result_lf.join(
flt_lf, on=join_key, how="left", maintain_order="left"
)
applicable_names = [fname for fname, _, _ in applicable]
new_members[member_name] = result_lf.filter(
all_rules_required(
applicable_names,
null_is_valid=False,
schema_name=member_name,
)
).drop(applicable_names)
members = new_members

return cls._init(members)

Expand Down Expand Up @@ -489,18 +519,42 @@ def is_valid(cls, data: Mapping[str, FrameType], /, *, cast: bool = False) -> bo
return False
members[member] = data[member].lazy()

# Make sure that inner-joining all filters does not remove any rows
if filters := cls._filters().values():
# Make sure that applying all filters does not remove any rows from applicable members
if filters := cls._filters():
result_cls = cls._init(members)
primary_key = cls.common_primary_key()
keep = [filter.logic(result_cls).select(primary_key) for filter in filters]
joined = _join_all(*keep, on=primary_key, how="inner")
removed_rows = pl.collect_all(
data[member].lazy().join(joined, on=primary_key, how="anti")
for member in cls.members()
if member in data
)
return all(df.is_empty() for df in removed_rows)
common_pk = cls.common_primary_key()

# Compute filter results with their per-filter join keys
filter_results: dict[str, tuple[pl.LazyFrame, list[str]]] = {}
for name, f in filters.items():
if f.members is None:
join_key = common_pk
else:
join_key = sorted(
_common_primary_key(cls.members()[m].schema for m in f.members)
)
filter_results[name] = (f.logic(result_cls).select(join_key), join_key)

# For each non-ignored member, check each applicable filter
check_frames = []
for member_name in cls.members():
if member_name not in data:
continue
member_info = cls.members()[member_name]
if member_info.ignored_in_filters:
continue
for filter_name, (filter_keep, join_key) in filter_results.items():
f = filters[filter_name]
if f.members is not None and member_name not in f.members:
continue
check_frames.append(
data[member_name]
.lazy()
.join(filter_keep, on=join_key, how="anti")
)
if check_frames:
removed_rows = pl.collect_all(check_frames)
return all(df.is_empty() for df in removed_rows)

return True

Expand Down Expand Up @@ -578,13 +632,18 @@ class HospitalInvoiceData(dy.Collection):
result_cls = cls._init(results)
primary_key = cls.common_primary_key()

keep: dict[str, pl.LazyFrame] = {}
for name, filter in filters.items():
# keep maps filter_name -> (lazy_frame_of_kept_keys, join_key)
keep: dict[str, tuple[pl.LazyFrame, list[str]]] = {}
for name, f in filters.items():
if f.members is None:
join_key = primary_key
else:
join_key = sorted(
_common_primary_key(cls.members()[m].schema for m in f.members)
)
keep[name] = (
filter.logic(result_cls)
.select(primary_key)
.pipe(collect_if, eager)
.lazy()
f.logic(result_cls).select(join_key).pipe(collect_if, eager).lazy(),
join_key,
)

drop: dict[str, pl.LazyFrame] = {}
Expand All @@ -605,11 +664,22 @@ class HospitalInvoiceData(dy.Collection):
if member_info.ignored_in_filters:
continue

# Only apply filters that are applicable to this member
applicable_keep = {
name: (filter_keep, filter_key)
for name, (filter_keep, filter_key) in keep.items()
if _filter_applies_to_member(filters[name].members, member_name)
}

all_filter_columns = list(applicable_keep.keys()) + list(drop.keys())
if not all_filter_columns:
continue

lf_with_eval = filtered.lazy()
for name, filter_keep in keep.items():
for name, (filter_keep, filter_key) in applicable_keep.items():
lf_with_eval = lf_with_eval.join(
filter_keep.lazy().with_columns(pl.lit(True).alias(name)),
on=primary_key,
on=filter_key,
how="left",
maintain_order="left",
).with_columns(pl.col(name).fill_null(False))
Expand All @@ -626,7 +696,6 @@ class HospitalInvoiceData(dy.Collection):
# Filtering `lf_with_eval` by the rows for which all joins
# "succeeded", we can identify the rows that pass all the filters. We
# keep these rows for the result.
all_filter_columns = list(keep.keys()) + list(drop.keys())
results[member_name] = lf_with_eval.filter(
pl.all_horizontal(all_filter_columns)
).drop(all_filter_columns)
Expand Down Expand Up @@ -848,8 +917,11 @@ def serialize(cls) -> str:
for name, info in cls.members().items()
},
"filters": {
name: filter.logic(cls.create_empty())
for name, filter in cls._filters().items()
name: {
"logic": f.logic(cls.create_empty()),
"members": sorted(f.members) if f.members is not None else None,
}
for name, f in cls._filters().items()
},
}
return json.dumps(result, cls=SchemaJSONEncoder)
Expand Down Expand Up @@ -1347,8 +1419,17 @@ def deserialize_collection(data: str, strict: bool = True) -> type[Collection] |
{
"__annotations__": annotations,
**{
name: Filter(logic=lambda _, logic=logic: logic) # type: ignore
for name, logic in decoded["filters"].items()
name: (
# New format: filter_data is a dict with "logic" and "members"
Filter(
logic=lambda _, logic=filter_data["logic"]: logic, # type: ignore
members=filter_data.get("members"),
)
if isinstance(filter_data, dict)
# Old format: filter_data is a LazyFrame directly
else Filter(logic=lambda _, logic=filter_data: logic) # type: ignore
)
for name, filter_data in decoded["filters"].items()
},
},
)
Expand All @@ -1363,6 +1444,19 @@ def deserialize_collection(data: str, strict: bool = True) -> type[Collection] |
# --------------------------------------- UTILS -------------------------------------- #


def _filter_applies_to_member(
filter_members: Sequence[str] | None, member_name: str
) -> bool:
"""Return True if the filter applies to the given member.

If ``filter_members`` is None, the filter applies to all non-ignored members.
Otherwise, the filter only applies to the members listed in ``filter_members``.
"""
if filter_members is None:
return True
return member_name in filter_members


def _join_all(
*dfs: pl.LazyFrame,
on: list[str],
Expand Down
Loading