diff --git a/docs/language/reference/dataset_methods.md b/docs/language/reference/dataset_methods.md
index 9a9a701..05503fd 100644
--- a/docs/language/reference/dataset_methods.md
+++ b/docs/language/reference/dataset_methods.md
@@ -20,6 +20,7 @@ The Substrait helper surface behind these methods is split by semantic role:
 | `group_by` | `def group_by(self, columns: list[ColumnExpr]) -> Self` | Define grouping keys using scalar expressions. |
 | `agg` | `def agg(self, measures: list[AggregateMeasure]) -> Self` | Apply aggregate measures over the current relation or current grouping. |
 | `generate` | `def generate(self, generator: GeneratorApplication) -> Self` | Apply a relation-shaping generator such as `explode(...)` with explicit output aliases. |
+| `with_window_column` | `def with_window_column(self, name: str, application: WindowFunctionApplication) -> Self` | Add or replace one projected column using a placed window function. |
 | `order_by` | `def order_by(self, columns: list[ColumnExpr]) -> Self` | Sort rows by scalar expressions or ordering helpers such as `asc(...)` and `desc(...)`. |
 | `limit` | `def limit(self, n: int) -> Self` | Cap row count. |
 | `explode` | `def explode(self) -> Self` | Compatibility marker for the older EXPLODE extension path. Prefer `generate(explode(...))`. |
@@ -69,6 +70,7 @@ def enrich(orders: LazyFrame[Order]) -> LazyFrame[Order]:
 - `join(...)` is constrained to same-carrier inputs and the boolean join predicate surface shown in the signature.
 - `select(...)` preserves projection shape; explicit projection lists are represented today through `with_column(...)` and scalar-expression builders.
 - `generate(...)` preserves all input columns and appends generated output aliases. Alias collisions are rejected during planning/lowering.
+- `with_window_column(...)` currently supports ranking helpers over explicit window specs and lowers through Substrait window relations. Backend execution support is tracked separately from logical planning support.
 - `DataFrame[T]` exposes materialized metadata and preview text; row-level accessors belong to the materialized DataFrame API surface.
 - Query-block and scoped DSL surfaces lower into these builder APIs rather than defining separate method semantics.
 
@@ -77,3 +79,4 @@ def enrich(orders: LazyFrame[Order]) -> LazyFrame[Order]:
 - [Filter builders](builders/filters.md)
 - [Aggregate builders](builders/aggregates.md)
 - [Projection builders](builders/projections.md)
+- [Window functions](functions/windows.md)
diff --git a/docs/language/reference/functions/index.md b/docs/language/reference/functions/index.md
index e65ea90..adc070d 100644
--- a/docs/language/reference/functions/index.md
+++ b/docs/language/reference/functions/index.md
@@ -9,10 +9,11 @@ Today the concrete shipped surfaces are documented here:
 - [Projection builders](../builders/projections.md)
 - [Generator and table-valued functions](generators.md)
 - [Nested data functions](nested.md)
+- [Window functions](windows.md)
 
 The canonical scalar literal helper is `lit(...)`. Typed literal helpers construct the same scalar-expression representation.
 
-The current registry-backed helper surface is registered in the package-owned function registry. Registry types live in `src/function_registry.incn`, the shared package registry lives in `src/functions/registry.incn`, and concrete public helper entries are produced by `function_registry.add(...)` decorators in individual `src/functions//.incn` modules. The registry-backed families are references, literals, casts, operators, predicates, conditionals, math, ordering, aggregates, generators, and nested data. Each runtime entry exposes a stable function reference such as `inql.functions.col`, namespace, canonical name, typed lifecycle metadata (`since`, versioned changes, and optional deprecation), RFC 024 policy category, function class, null behavior, alias policy, aggregate modifier policy, and Substrait mapping metadata. Checked function signatures come from the public helper declaration, not from a second hand-written registry signature.
+The current registry-backed helper surface is registered in the package-owned function registry. Registry types live in `src/function_registry.incn`, the shared package registry lives in `src/functions/registry.incn`, and concrete public helper entries are produced by `function_registry.add(...)` decorators in individual `src/functions//.incn` modules. The registry-backed families are references, literals, casts, operators, predicates, conditionals, math, ordering, aggregates, generators, nested data, and windows. Each runtime entry exposes a stable function reference such as `inql.functions.col`, namespace, canonical name, typed lifecycle metadata (`since`, versioned changes, and optional deprecation), RFC 024 policy category, function class, null behavior, alias policy, aggregate modifier policy, and Substrait mapping metadata. Checked function signatures come from the public helper declaration, not from a second hand-written registry signature.
 
 The registry is the source for non-derivable machine facts. Public helper declarations are the source for argument names, argument types, and return types. Docstrings remain human-facing explanation, examples, and parameter intent. The `registry-metadata` check validates the checked API metadata projections produced from public facade aliases, registry decorators, and decorated callable signatures. Runtime registry entries are lazy and process-local: they support helper execution and lowering for loaded helpers, while the complete public catalog comes from checked metadata. This matters for generated docs, diagnostics, Prism lowering, and backend capability checks as the catalog grows.
 
@@ -35,6 +36,7 @@ The registered helper surface currently includes:
 | `abs(...)`, `ceil(...)`, `floor(...)`, `round(...)` | scalar | registered Substrait math scalar mappings; `round(...)` is currently the single-argument form |
 | `array(...)`, `cardinality(...)`, `array_contains(...)`, `arrays_overlap(...)`, `array_position(...)`, `element_at(...)`, `array_sort(...)`, `array_distinct(...)`, `array_except(...)`, `array_intersect(...)`, `array_union(...)`, `array_join(...)`, `array_slice(...)`, `array_reverse(...)`, `array_flatten(...)`, `map_from_arrays(...)`, `map_extract(...)`, `map_contains_key(...)`, `map_keys(...)`, `map_values(...)`, `map_entries(...)`, `named_struct(...)` | scalar | registered nested scalar helpers backed by Substrait extension mappings; `map_contains_key(...)` lowers as a documented predicate rewrite |
 | `explode(...)`, `explode_outer(...)`, `posexplode(...)`, `posexplode_outer(...)` | generator | relation-extension mappings consumed by `generate(...)`; positional forms use zero-based positions |
+| `window()`, `row_number()`, `rank()`, `dense_rank()` | window | `window()` builds structural window-spec metadata; ranking helpers lower through `ConsistentPartitionWindowRel` when placed with `with_window_column(...)` |
 | `asc(...)`, `desc(...)`, `asc_nulls_first(...)`, `asc_nulls_last(...)`, `desc_nulls_first(...)`, `desc_nulls_last(...)` | ordering | structural sort-field helpers consumed by `order_by(...)` and lowered to Substrait `SortRel.sorts` |
 | `sum(...)`, `count()`, `count_expr(...)`, `avg(...)`, `min(...)`, `max(...)` | aggregate | registered Substrait extension functions; `count_expr(...)` is a compatibility spelling for future `count(expr)` helper overloading |
 | `count_distinct(...)`, `count_if(...)` | aggregate | compatibility helpers that lower through aggregate modifiers over canonical `count` semantics |
diff --git a/docs/language/reference/functions/windows.md b/docs/language/reference/functions/windows.md
new file mode 100644
index 0000000..37185ae
--- /dev/null
+++ b/docs/language/reference/functions/windows.md
@@ -0,0 +1,33 @@
+# Window Functions (Reference)
+
+Window helpers are relation-aware. A window function application produces one output value per input row while reading a
+partition of related rows. It is not an ordinary scalar expression and must be placed through a projection-like dataset
+method.
+
+```incan
+from pub::inql import LazyFrame
+from pub::inql.functions import col, desc, rank, window
+from models import Order
+
+def ranked_orders(orders: LazyFrame[Order]) -> LazyFrame[Order]:
+    return orders.with_window_column(
+        "customer_rank",
+        rank().over(window().partition_by([col("customer_id")]).order_by([desc(col("amount"))])),
+    )
+```
+
+The current foundation slice includes:
+
+| Function | Meaning | Placement |
+| --- | --- | --- |
+| `window()` | Build an empty window specification. | Structural builder used before `.over(...)`. |
+| `row_number()` | Assign a sequential row number inside the ordered window. | Use `.over(window().order_by(...))`, then `with_window_column(...)`. |
+| `rank()` | Rank rows with gaps after ties inside the ordered window. | Use `.over(window().order_by(...))`, then `with_window_column(...)`. |
+| `dense_rank()` | Rank rows without gaps after ties inside the ordered window. | Use `.over(window().order_by(...))`, then `with_window_column(...)`. |
+
+`WindowSpec.partition_by(...)` replaces the partition expressions. `WindowSpec.order_by(...)` replaces the ordering
+expressions. Ranking helpers require explicit ordering; missing ordering is rejected during logical lowering.
+
+`with_window_column(name, application)` preserves input columns and adds or replaces `name` using add-or-replace
+projection semantics. Each call lowers one window projection through Substrait `ConsistentPartitionWindowRel` with a
+registry-backed function anchor. Backend execution support is separate from this logical planning surface.
diff --git a/docs/language/reference/substrait/operator_catalog.md b/docs/language/reference/substrait/operator_catalog.md
index 327ad49..81bfecb 100644
--- a/docs/language/reference/substrait/operator_catalog.md
+++ b/docs/language/reference/substrait/operator_catalog.md
@@ -34,7 +34,7 @@ The following table maps InQL plan capabilities to Substrait logical relations a
 | Group by / aggregates | `AggregateRel` with scalar grouping keys and aggregate measures; grouping sets are tracked as a distinct capability below | core |
 | Rollup / cube / grouping sets | `AggregateRel` with multiple groupings | core |
 | Distinct rows | `AggregateRel` with grouping keys and no measures | core |
-| Window / analytic functions | `ProjectRel` with window expressions | core |
+| Window / analytic functions | `ConsistentPartitionWindowRel` with partition/order expressions and registered window function anchors | core |
 | Sort | `SortRel` | core |
 | Limit / offset | `FetchRel` | core |
 | Union, intersect, except | `SetRel` with the appropriate set operation enum | core |
diff --git a/docs/release_notes/v0_1.md b/docs/release_notes/v0_1.md
index 5c23085..d337e4e 100644
--- a/docs/release_notes/v0_1.md
+++ b/docs/release_notes/v0_1.md
@@ -17,6 +17,7 @@ Entries will be filled in as work lands (link RFCs and PRs when applicable).
 - **Common scalar functions:** The first RFC 018 slice adds registry-backed math helpers for `abs(...)`, `ceil(...)`, `floor(...)`, and single-argument `round(...)`, with Substrait mappings and DataFusion-backed execution coverage.
 - **Nested data functions:** RFC 020 adds registry-backed scalar helpers for array construction/access, cardinality, containment, overlap, sorting, set-like operations, joining, slicing, reversing, scalar array flattening, map construction/access, map key/value/entry extraction, map key containment, and named struct construction. These helpers lower through Substrait extension metadata and execute through the DataFusion-backed Session path without introducing generator semantics.
 - **Generator functions:** RFC 021 adds registry-backed generator applications for `explode(...)`, `explode_outer(...)`, `posexplode(...)`, and `posexplode_outer(...)`. Generators remain relation-shaping operations applied with `generate(...)`; they preserve input columns, require explicit output aliases, and lower through the current Substrait extension-relation gap encoding.
+- **Window functions:** RFC 019 adds the first window-function planning slice with `window()` specs, `row_number()`, `rank()`, `dense_rank()`, and `with_window_column(...)`. Ranking windows require explicit ordering and lower through Substrait `ConsistentPartitionWindowRel`; backend execution support remains a separate adapter capability.
 - **Function registry:** RFC 014 adds declaration-site registry decorators for the current public helper surface, including stable function references, checked signature projection, lifecycle metadata, behavior categories, alias policy, Substrait mapping categories, and checked API metadata drift validation.
 - **Function extension policy:** RFC 024 policy metadata now distinguishes portable core functions, namespaced extension-only functions, opt-in compatibility aliases, engine-specific functions, and rejected compatibility requests without adding an extension plugin system or backend-owned semantics.
 - **Projection:** builder-based `with_column`, `add`, `mul`, and literal expression helpers now lower derived columns through Prism, Substrait, and Session execution.
diff --git a/docs/rfcs/019_window_functions.md b/docs/rfcs/019_window_functions.md
index 7e6fb1c..b509d88 100644
--- a/docs/rfcs/019_window_functions.md
+++ b/docs/rfcs/019_window_functions.md
@@ -1,6 +1,6 @@
 # InQL RFC 019: Window functions
 
-- **Status:** Draft
+- **Status:** In Progress
 - **Created:** 2026-04-27
 - **Author(s):** Danny Meijer (@dannymeijer)
 - **Related:**
@@ -11,7 +11,7 @@
   - InQL RFC 016 (core aggregate functions)
 - **Issue:** [InQL #36](https://github.com/dannys-code-corner/InQL/issues/36)
 - **RFC PR:** —
-- **Written against:** Incan v0.2
+- **Written against:** Incan v0.3-era InQL
 - **Shipped in:** —
 
 ## Summary
@@ -40,19 +40,18 @@ Window functions also force a clearer relation between row-level expressions and
 
 ## Guide-level explanation (how authors think about it)
 
-Authors should be able to rank and compare rows within a partition:
+Authors can rank rows within a partition using the builder surface:
 
 ```incan
-from pub::inql.functions import col, desc, lag, rank, window
+from pub::inql.functions import col, desc, rank, window
 
 ranked = (
     orders
-    .with_column("customer_rank", rank().over(window().partition_by([col("customer_id")]).order_by([desc(col("amount"))])))
-    .with_column("previous_amount", lag(col("amount"), 1).over(window().partition_by([col("customer_id")]).order_by([col("created_at")])))
+    .with_window_column("customer_rank", rank().over(window().partition_by([col("customer_id")]).order_by([desc(col("amount"))])))
 )
 ```
 
-The exact builder syntax may evolve, but authors should understand that a window function returns a row-level value computed with access to nearby or related rows.
+The exact query-block syntax may evolve, but authors should understand that a window function returns a row-level value computed with access to nearby or related rows.
 
 ## Reference-level explanation (precise rules)
 
@@ -108,10 +107,17 @@ No current InQL function should be reclassified silently as a window function. A
 - **Execution / interchange** — Prism and Substrait lowering must preserve window partitioning, ordering, frames, and function identity.
 - **Documentation** — docs should clearly separate aggregate functions from window functions.
 
-## Unresolved questions
+## Design Decisions
+
+### Resolved
+
+- The first implementation slice exposes explicit `with_window_column(...)` projection-like placement rather than accepting window functions in arbitrary scalar-expression positions.
+- Ranking helpers require explicit `order_by(...)` in the window spec. InQL does not invent a silent default ordering.
+- The current foundation slice lowers `row_number`, `rank`, and `dense_rank` through `ConsistentPartitionWindowRel` with registry-backed function anchors.
+- DataFusion execution for window relations is not claimed until a backend adapter slice explicitly supports the lowered window relation.
+
+### Remaining
 
 - What default frame should InQL use for ordered window functions?
-- Should window functions be allowed in `WHERE` or only in projection/order positions?
 - Should null treatment use explicit `IGNORE NULLS` / `RESPECT NULLS` style modifiers?
-
-
+- How should `lag`, `lead`, first/last/nth value functions, aggregate-over-window calls, and query-block `OVER (...)` syntax be phased on top of the foundation model?
diff --git a/docs/rfcs/README.md b/docs/rfcs/README.md
index c42c434..4ec010f 100644
--- a/docs/rfcs/README.md
+++ b/docs/rfcs/README.md
@@ -25,7 +25,7 @@ InQL uses its **own** RFC series (starting at 000), independent of the [Incan la
 | [016][rfc-016] | Draft | Core aggregate functions | |
 | [017][rfc-017] | Draft | Aggregate modifiers | |
 | [018][rfc-018] | Draft | Common scalar function catalog | |
-| [019][rfc-019] | Draft | Window functions | |
+| [019][rfc-019] | In Progress | Window functions | |
 | [020][rfc-020] | Draft | Nested data functions | |
 | [021][rfc-021] | In Progress | Generator and table-valued functions | |
 | [022][rfc-022] | Draft | Semi-structured and format functions | |
diff --git a/src/dataset/mod.incn b/src/dataset/mod.incn
index e9b31b1..9a1b134 100644
--- a/src/dataset/mod.incn
+++ b/src/dataset/mod.incn
@@ -23,6 +23,7 @@ The current method-chain surface in this module is the explicit builder-based AP
 - `group_by(columns: list[ColumnExpr])`
 - `agg(measures: list[AggregateMeasure])`
 - `generate(generator: GeneratorApplication)`
+- `with_window_column(name: str, application: WindowFunctionApplication)`
 - plus the structural operators `join`, `select`, `order_by`, `limit`, and `explode`
 
 Illustrative current-shape examples:
@@ -56,6 +57,7 @@ from rust::substrait::proto import Plan, Rel
 from aggregate_builders import AggregateMeasure
 from generator_builders import GeneratorApplication
 from projection_builders import ColumnExpr
+from window_builders import WindowFunctionApplication
 from dataset.materialization import DataFrameMaterialization
 from substrait.errors import SubstraitLoweringError
 from substrait.schema_registry import named_table_columns
@@ -72,6 +74,7 @@ from dataset.ops import (
     order_by_ds_of_columns,
     select_ds_of_columns,
     with_column_ds,
+    with_window_column_ds,
 )
 from session.types import SessionError, collect_with_active_session
 from prism import (
@@ -86,6 +89,7 @@ from prism import (
     prism_cursor_apply_order_by,
     prism_cursor_apply_select,
     prism_cursor_apply_with_column,
+    prism_cursor_apply_with_window_column,
     prism_cursor_named_table,
     prism_cursor_output_columns,
 )
@@ -103,6 +107,7 @@ pub trait DataSet[T with Clone]:
     def group_by(self, columns: list[ColumnExpr]) -> Self
     def agg(self, measures: list[AggregateMeasure]) -> Self
     def generate(self, generator: GeneratorApplication) -> Self
+    def with_window_column(self, name: str, application: WindowFunctionApplication) -> Self
     def order_by(self, columns: list[ColumnExpr]) -> Self
     def limit(self, n: int) -> Self
     def explode(self) -> Self
@@ -218,6 +223,12 @@ pub class DataFrame[T with Clone] with BoundedDataSet:
             generate_ds_of_columns(self._substrait_rel, self.planned_columns(), generator),
         )
 
+    def with_window_column(self, name: str, application: WindowFunctionApplication) -> Self:
+        """Return one new DataFrame with a named window projection stage and stale materialization cleared."""
+        return _data_frame_with_invalidated_materialization(
+            with_window_column_ds(self._substrait_rel, self.planned_columns(), name, application),
+        )
+
     def order_by(self, columns: list[ColumnExpr]) -> Self:
         """Return one new DataFrame with an ordering stage and stale materialization cleared."""
         return _data_frame_with_invalidated_materialization(
@@ -303,6 +314,10 @@ pub class LazyFrame[T with Clone] with BoundedDataSet:
         """Return one new lazy carrier with an appended generator stage."""
         return LazyFrame(_cursor=prism_cursor_apply_generate(self._cursor, generator))
 
+    def with_window_column(self, name: str, application: WindowFunctionApplication) -> Self:
+        """Return one new lazy carrier with an appended named window projection stage."""
+        return LazyFrame(_cursor=prism_cursor_apply_with_window_column(self._cursor, name, application))
+
     def order_by(self, columns: list[ColumnExpr]) -> Self:
         """Return one new lazy carrier with an appended ordering stage."""
         return LazyFrame(_cursor=prism_cursor_apply_order_by(self._cursor, columns))
@@ -456,6 +471,18 @@ pub class DataStream[T with Clone] with UnboundedDataSet:
             ),
         )
 
+    def with_window_column(self, name: str, application: WindowFunctionApplication) -> Self:
+        """Return one new DataStream with a named window projection stage."""
+        return DataStream(
+            _row_schema_marker=self._row_schema_marker.clone(),
+            _substrait_rel=with_window_column_ds(
+                self._substrait_rel,
+                relation_output_columns(self._substrait_rel.clone()),
+                name,
+                application,
+            ),
+        )
+
     def order_by(self, columns: list[ColumnExpr]) -> Self:
         """Return one new DataStream with an ordering stage."""
         return DataStream(
diff --git a/src/dataset/ops.incn b/src/dataset/ops.incn
index bafad30..d8c90c2 100644
--- a/src/dataset/ops.incn
+++ b/src/dataset/ops.incn
@@ -10,6 +10,7 @@ from rust::substrait::proto import Rel
 from aggregate_builders import AggregateMeasure
 from generator_builders import GeneratorApplication
 from projection_builders import ColumnExpr, ProjectionAssignment, with_column_assignment
+from window_builders import WindowFunctionApplication, window_projection
 from substrait.function_extensions import explode_extension_uri
 from substrait.inspect import relation_output_columns
 from substrait.relations import (
@@ -21,6 +22,7 @@ from substrait.relations import (
     project_rel_of_columns,
     sort_rel_of_columns,
     generator_rel_of_columns,
+    window_rel_of_columns,
 )
 
 
@@ -134,6 +136,16 @@ pub def generate_ds_of_columns(rel: Rel, input_columns: list[str], generator: Ge
     return generator_rel_of_columns(rel, input_columns, generator)
 
 
+pub def with_window_column_ds(
+    rel: Rel,
+    input_columns: list[str],
+    name: str,
+    application: WindowFunctionApplication,
+) -> Rel:
+    """Apply one dataset-level named window projection using explicit input-column names."""
+    return window_rel_of_columns(rel, input_columns, [window_projection(name, application)])
+
+
 pub def order_by_ds(rel: Rel, columns: list[ColumnExpr]) -> Rel:
     """
     Apply dataset-level ordering intent to one relation.
diff --git a/src/functions/mod.incn b/src/functions/mod.incn
index c20b662..1cfc03c 100644
--- a/src/functions/mod.incn
+++ b/src/functions/mod.incn
@@ -65,6 +65,10 @@ pub from functions.generators.explode import explode
 pub from functions.generators.explode_outer import explode_outer
 pub from functions.generators.posexplode import posexplode
 pub from functions.generators.posexplode_outer import posexplode_outer
+pub from functions.windows.window import window
+pub from functions.windows.row_number import row_number
+pub from functions.windows.rank import rank
+pub from functions.windows.dense_rank import dense_rank
 pub from functions.operators.add import add
 pub from functions.operators.and_ import and_
 pub from functions.operators.div import div
diff --git a/src/functions/windows/dense_rank.incn b/src/functions/windows/dense_rank.incn
new file mode 100644
index 0000000..2545172
--- /dev/null
+++ b/src/functions/windows/dense_rank.incn
@@ -0,0 +1,36 @@
+"""Dense-rank window helper."""
+
+from function_registry import (
+    FunctionClass,
+    FunctionLifecycle,
+    FunctionNullBehavior,
+    deterministic_spec,
+    extension_mapping,
+    v0_1,
+)
+from functions.registry import function_registry
+from substrait.function_extensions import DENSE_RANK_FUNCTION_ANCHOR
+from window_builders import WindowFunctionCall, dense_rank as dense_rank_builder
+
+
+@function_registry.add("dense_rank", deterministic_spec(
+    FunctionClass.Window,
+    FunctionLifecycle(since=v0_1, changed=[], deprecated=None),
+    FunctionNullBehavior.NonNullLiteral,
+    extension_mapping("dense_rank", DENSE_RANK_FUNCTION_ANCHOR),
+))
+pub def dense_rank() -> WindowFunctionCall:
+    """
+    Build a dense-rank window function call.
+
+    Examples:
+        ranked = dense_rank().over(window().partition_by([col("customer_id")]).order_by([desc(col("amount"))]))
+    """
+    return dense_rank_builder()
+
+
+module tests:
+    def test_dense_rank_builds_window_call() -> None:
+        call = dense_rank()
+        assert call.canonical_name == "dense_rank"
+        assert call.requires_ordering
diff --git a/src/functions/windows/mod.incn b/src/functions/windows/mod.incn
new file mode 100644
index 0000000..3c203cf
--- /dev/null
+++ b/src/functions/windows/mod.incn
@@ -0,0 +1,6 @@
+"""Window specification and ranking helper functions."""
+
+pub from functions.windows.window import window
+pub from functions.windows.row_number import row_number
+pub from functions.windows.rank import rank
+pub from functions.windows.dense_rank import dense_rank
diff --git a/src/functions/windows/rank.incn b/src/functions/windows/rank.incn
new file mode 100644
index 0000000..54f7ff0
--- /dev/null
+++ b/src/functions/windows/rank.incn
@@ -0,0 +1,36 @@
+"""Rank window helper."""
+
+from function_registry import (
+    FunctionClass,
+    FunctionLifecycle,
+    FunctionNullBehavior,
+    deterministic_spec,
+    extension_mapping,
+    v0_1,
+)
+from functions.registry import function_registry
+from substrait.function_extensions import RANK_FUNCTION_ANCHOR
+from window_builders import WindowFunctionCall, rank as rank_builder
+
+
+@function_registry.add("rank", deterministic_spec(
+    FunctionClass.Window,
+    FunctionLifecycle(since=v0_1, changed=[], deprecated=None),
+    FunctionNullBehavior.NonNullLiteral,
+    extension_mapping("rank", RANK_FUNCTION_ANCHOR),
+))
+pub def rank() -> WindowFunctionCall:
+    """
+    Build a rank window function call.
+
+    Examples:
+        ranked = rank().over(window().partition_by([col("customer_id")]).order_by([desc(col("amount"))]))
+    """
+    return rank_builder()
+
+
+module tests:
+    def test_rank_builds_window_call() -> None:
+        call = rank()
+        assert call.canonical_name == "rank"
+        assert call.requires_ordering
diff --git a/src/functions/windows/row_number.incn b/src/functions/windows/row_number.incn
new file mode 100644
index 0000000..f22ee64
--- /dev/null
+++ b/src/functions/windows/row_number.incn
@@ -0,0 +1,36 @@
+"""Row-number window helper."""
+
+from function_registry import (
+    FunctionClass,
+    FunctionLifecycle,
+    FunctionNullBehavior,
+    deterministic_spec,
+    extension_mapping,
+    v0_1,
+)
+from functions.registry import function_registry
+from substrait.function_extensions import ROW_NUMBER_FUNCTION_ANCHOR
+from window_builders import WindowFunctionCall, row_number as row_number_builder
+
+
+@function_registry.add("row_number", deterministic_spec(
+    FunctionClass.Window,
+    FunctionLifecycle(since=v0_1, changed=[], deprecated=None),
+    FunctionNullBehavior.NonNullLiteral,
+    extension_mapping("row_number", ROW_NUMBER_FUNCTION_ANCHOR),
+))
+pub def row_number() -> WindowFunctionCall:
+    """
+    Build a row-number window function call.
+
+    Examples:
+        numbered = row_number().over(window().partition_by([col("customer_id")]).order_by([col("created_at")]))
+    """
+    return row_number_builder()
+
+
+module tests:
+    def test_row_number_builds_window_call() -> None:
+        call = row_number()
+        assert call.canonical_name == "row_number"
+        assert call.requires_ordering
diff --git a/src/functions/windows/window.incn b/src/functions/windows/window.incn
new file mode 100644
index 0000000..7bcdc18
--- /dev/null
+++ b/src/functions/windows/window.incn
@@ -0,0 +1,35 @@
+"""Window specification builder helper."""
+
+from function_registry import (
+    FunctionClass,
+    FunctionLifecycle,
+    FunctionNullBehavior,
+    deterministic_spec,
+    structural_mapping,
+    v0_1,
+)
+from functions.registry import function_registry
+from window_builders import WindowSpec, window as window_builder
+
+
+@function_registry.add("window", deterministic_spec(
+    FunctionClass.Window,
+    FunctionLifecycle(since=v0_1, changed=[], deprecated=None),
+    FunctionNullBehavior.NonNullLiteral,
+    structural_mapping("window_spec"),
+))
+pub def window() -> WindowSpec:
+    """
+    Build an empty window specification.
+
+    Examples:
+        spec = window().partition_by([col("customer_id")]).order_by([col("created_at")])
+    """
+    return window_builder()
+
+
+module tests:
+    def test_window_builds_empty_window_spec() -> None:
+        spec = window()
+        assert len(spec.partition_columns) == 0
+        assert len(spec.sort_columns) == 0
diff --git a/src/lib.incn b/src/lib.incn
index 87604d7..a707823 100644
--- a/src/lib.incn
+++ b/src/lib.incn
@@ -16,6 +16,7 @@ pub from dataset.ops import (
     limit_ds,
     order_by_ds,
     select_ds,
+    with_window_column_ds,
 )
 pub from aggregate_builders import AggregateKind, AggregateMeasure
 pub from generator_builders import (
@@ -24,6 +25,15 @@ pub from generator_builders import (
     generator_output_columns,
     generator_primary_output_column,
 )
+pub from window_builders import (
+    WindowFunctionApplication,
+    WindowFunctionCall,
+    WindowFunctionKind,
+    WindowProjection,
+    WindowSpec,
+    window_output_columns,
+    window_projection,
+)
 pub from projection_builders import (
     BoolLiteralExpr,
     ColumnExpr,
@@ -101,6 +111,10 @@ pub from functions.generators.explode import explode
 pub from functions.generators.explode_outer import explode_outer
 pub from functions.generators.posexplode import posexplode
 pub from functions.generators.posexplode_outer import posexplode_outer
+pub from functions.windows.window import window
+pub from functions.windows.row_number import row_number
+pub from functions.windows.rank import rank
+pub from functions.windows.dense_rank import dense_rank
 pub from functions.operators.add import add
 pub from functions.operators.and_ import and_
 pub from functions.operators.div import div
@@ -218,6 +232,8 @@ pub from substrait.relations import (
     set_rel_of_kind,
     sort_rel,
     sort_rel_of_columns,
+    window_rel,
+    window_rel_of_columns,
 )
 pub from substrait.plans import (
     empty_plan,
@@ -244,6 +260,9 @@ pub from substrait.inspect import (
     rel_contains_kind,
     root_rel,
     set_operation_name,
+    window_function_names,
+    window_partition_count,
+    window_sort_count,
 )
 pub from substrait.function_extensions import (
     explode_extension_uri,
@@ -251,6 +270,9 @@ pub from substrait.function_extensions import (
     function_extension_uri,
     posexplode_extension_uri,
     posexplode_outer_extension_uri,
+    DENSE_RANK_FUNCTION_ANCHOR,
+    RANK_FUNCTION_ANCHOR,
+    ROW_NUMBER_FUNCTION_ANCHOR,
     registered_substrait_extension_uris,
 )
 pub from substrait.conformance_catalog import (
diff --git a/src/prism/lower.incn b/src/prism/lower.incn
index 6020b57..078a2cd 100644
--- a/src/prism/lower.incn
+++ b/src/prism/lower.incn
@@ -15,6 +15,7 @@ from substrait.relations import (
     try_aggregate_rel_of_columns,
     try_filter_rel_of_columns,
     try_project_rel_of_columns,
+    try_window_rel_of_columns,
 )
 from substrait.errors import SubstraitLoweringError
 from prism.rewrite import derive_rewritten_view, rewritten_node_at
@@ -125,6 +126,12 @@ def _try_lower_node(view: PrismOptimizedView, node_id: int) -> Result[Rel, Subst
                 rewritten_output_columns(view, node.input_ids[0]),
                 node.generator_applications[0],
             )
+        PrismNodeKind.Window =>
+            return try_window_rel_of_columns(
+                _try_lower_node(view, node.input_ids[0])?,
+                rewritten_output_columns(view, node.input_ids[0]),
+                node.window_projections,
+            )
         PrismNodeKind.OrderBy =>
             return Ok(
                 sort_rel_of_columns(
diff --git a/src/prism/mod.incn b/src/prism/mod.incn
index 3564d35..09d2933 100644
--- a/src/prism/mod.incn
+++ b/src/prism/mod.incn
@@ -15,6 +15,7 @@ from aggregate_builders import AggregateMeasure
 from filter_builders import always_true
 from generator_builders import GeneratorApplication
 from projection_builders import ColumnExpr, ProjectionAssignment, with_column_assignment
+from window_builders import WindowFunctionApplication, window_projection
 from prism.lower import (
     lower_prism_tip as lower_prism_tip_impl,
     prism_rel_to_plan,
@@ -71,6 +72,7 @@ pub class PrismCursor[T with Clone]:
             sort_columns=[],
             aggregate_measures=[],
             generator_applications=[],
+            window_projections=[],
             projection_assignments=[],
         )
         return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[])
@@ -90,6 +92,7 @@ pub class PrismCursor[T with Clone]:
             sort_columns=[],
             aggregate_measures=[],
             generator_applications=[],
+            window_projections=[],
             projection_assignments=[],
         )
         return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[])
@@ -106,6 +109,7 @@ pub class PrismCursor[T with Clone]:
             sort_columns=[],
             aggregate_measures=[],
             generator_applications=[],
+            window_projections=[],
             projection_assignments=[],
         )
         return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[])
@@ -124,6 +128,7 @@ pub class PrismCursor[T with Clone]:
             sort_columns=[],
             aggregate_measures=[],
             generator_applications=[],
+            window_projections=[],
             projection_assignments=[],
         )
         return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[])
@@ -142,6 +147,7 @@ pub class PrismCursor[T with Clone]:
             sort_columns=[],
             aggregate_measures=[],
             generator_applications=[],
+            window_projections=[],
             projection_assignments=[with_column_assignment(name, expr)],
         )
         return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[])
@@ -160,6 +166,7 @@ pub class PrismCursor[T with Clone]:
             sort_columns=[],
             aggregate_measures=[],
             generator_applications=[],
+            window_projections=[],
             projection_assignments=[],
         )
         return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[])
@@ -178,6 +185,7 @@ pub class PrismCursor[T with Clone]:
             sort_columns=[],
             aggregate_measures=measures,
             generator_applications=[],
+            window_projections=[],
             projection_assignments=[],
         )
         return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[])
@@ -196,6 +204,7 @@ pub class PrismCursor[T with Clone]:
             sort_columns=columns,
             aggregate_measures=[],
             generator_applications=[],
+            window_projections=[],
             projection_assignments=[],
         )
         return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[])
@@ -214,6 +223,7 @@ pub class PrismCursor[T with Clone]:
             sort_columns=[],
aggregate_measures=[], generator_applications=[], + window_projections=[], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -232,6 +242,7 @@ pub class PrismCursor[T with Clone]: sort_columns=[], aggregate_measures=[], generator_applications=[], + window_projections=[], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -250,6 +261,26 @@ pub class PrismCursor[T with Clone]: sort_columns=[], aggregate_measures=[], generator_applications=[generator], + window_projections=[], + projection_assignments=[], + ) + return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) + + def with_window_column(self, name: str, application: WindowFunctionApplication) -> Self: + """Append one named window projection and return the derived tip.""" + next_tip_id = append_node( + store_id=self.store_id, + kind=PrismNodeKind.Window, + input_ids=[self.tip_id], + named_table="", + join_predicate=false, + filter_predicate=always_true(), + limit_count=0, + group_columns=[], + sort_columns=[], + aggregate_measures=[], + generator_applications=[], + window_projections=[window_projection(name, application)], projection_assignments=[], ) return PrismCursor(store_id=self.store_id, tip_id=next_tip_id, _type_marker=[]) @@ -294,6 +325,7 @@ pub def prism_cursor_named_table[T with Clone](table_name: str) -> PrismCursor[T sort_columns=[], aggregate_measures=[], generator_applications=[], + window_projections=[], projection_assignments=[], ) return PrismCursor(store_id=store_id, tip_id=tip_id, _type_marker=[]) @@ -363,6 +395,15 @@ pub def prism_cursor_apply_generate[T with Clone]( return cursor.generate(generator) +pub def prism_cursor_apply_with_window_column[T with Clone]( + cursor: PrismCursor[T], + name: str, + application: WindowFunctionApplication, +) -> PrismCursor[T]: + """Apply one named window projection through Prism.""" + return 
cursor.with_window_column(name, application) + + pub def prism_cursor_output_columns[T with Clone](cursor: PrismCursor[T]) -> list[str]: """Return plan-time output columns for one cursor tip.""" return cursor.planned_columns() diff --git a/src/prism/output_columns.incn b/src/prism/output_columns.incn index d1cfa06..6c88c6a 100644 --- a/src/prism/output_columns.incn +++ b/src/prism/output_columns.incn @@ -7,6 +7,7 @@ from generator_builders import generator_output_columns from projection_builders import ColumnExpr, project_output_columns, scalar_expr_output_name from substrait.inspect import aggregate_measure_output_names from substrait.schema_registry import named_table_columns +from window_builders import window_output_columns def _is_passthrough_output_kind(kind: PrismNodeKind) -> bool: @@ -33,6 +34,8 @@ pub def authored_output_columns(store_id: PrismStoreId, tip_id: int) -> list[str authored_output_columns(store_id, node.input_ids[0]), node.generator_applications[0], ) + if node.kind == PrismNodeKind.Window: + return window_output_columns(authored_output_columns(store_id, node.input_ids[0]), node.window_projections) if node.kind == PrismNodeKind.Join: # Join output columns preserve the conventional left-then-right relation order. # We keep both sides verbatim here; duplicate names are part of the current output shape and are resolved later @@ -70,6 +73,8 @@ pub def rewritten_output_columns(view: PrismOptimizedView, node_id: int) -> list rewritten_output_columns(view, node.input_ids[0]), node.generator_applications[0], ) + if node.kind == PrismNodeKind.Window: + return window_output_columns(rewritten_output_columns(view, node.input_ids[0]), node.window_projections) if node.kind == PrismNodeKind.Join: # Rewritten views keep the same left-then-right join column order as authored views # so output-column inference stays stable across Prism rewrite passes. 
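Review note on the `output_columns.incn` hunks above: both `authored_output_columns` and `rewritten_output_columns` delegate Window nodes to `window_output_columns`, which this change defines in `src/window_builders.incn` with add-or-replace semantics. The following is a minimal Python sketch of that behavior, not the Incan source. It assumes projections can be reduced to their output names (the real helper takes `WindowProjection` values), and `index_of_text` is a stand-in for the private `_index_of_text` helper.

```python
def index_of_text(values: list[str], expected: str) -> int:
    """Return the index of a string value, or -1 when absent."""
    for idx, value in enumerate(values):
        if value == expected:
            return idx
    return -1


def window_output_columns(input_columns: list[str], output_names: list[str]) -> list[str]:
    """Return output columns after applying window outputs with add-or-replace semantics.

    Assumption: each projection is represented only by its output name.
    """
    # Copy so the caller's input column list is never mutated.
    output_columns = list(input_columns)
    for name in output_names:
        # An existing name keeps its position (replace); a new name is appended (add).
        if index_of_text(output_columns, name) < 0:
            output_columns.append(name)
    return output_columns
```

Under these assumptions, a window output that reuses an existing column name leaves the column count and order unchanged, while a new name lands at the end of the list. This matches the "Add or replace one projected column" wording in the method table.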
diff --git a/src/prism/rewrite.incn b/src/prism/rewrite.incn index 419f968..12095fd 100644 --- a/src/prism/rewrite.incn +++ b/src/prism/rewrite.incn @@ -169,6 +169,7 @@ def _build_collapsed_limit_node( sort_columns=[], aggregate_measures=[], generator_applications=[], + window_projections=[], projection_assignments=[], ) @@ -206,6 +207,7 @@ def _build_collapsed_project_node( sort_columns=[], aggregate_measures=[], generator_applications=[], + window_projections=[], projection_assignments=merged_assignments, ) @@ -243,6 +245,7 @@ def _build_collapsed_aggregate_node( sort_columns=[], aggregate_measures=merged_measures, generator_applications=[], + window_projections=[], projection_assignments=[], ) @@ -278,6 +281,7 @@ def _build_collapsed_order_by_node( sort_columns=node.sort_columns, aggregate_measures=[], generator_applications=[], + window_projections=[], projection_assignments=[], ) @@ -296,6 +300,7 @@ def _build_rewritten_node(node: PrismNode, remapped_inputs: list[int], rewritten sort_columns=node.sort_columns, aggregate_measures=node.aggregate_measures, generator_applications=node.generator_applications, + window_projections=node.window_projections, projection_assignments=node.projection_assignments, ) @@ -342,6 +347,7 @@ def _compact_optimized_view(view: PrismOptimizedView) -> PrismOptimizedView: sort_columns=old_node.sort_columns, aggregate_measures=old_node.aggregate_measures, generator_applications=old_node.generator_applications, + window_projections=old_node.window_projections, projection_assignments=old_node.projection_assignments, ), ) diff --git a/src/prism/store.incn b/src/prism/store.incn index e451ade..e7af5c3 100644 --- a/src/prism/store.incn +++ b/src/prism/store.incn @@ -2,6 +2,7 @@ from aggregate_builders import AggregateMeasure from generator_builders import GeneratorApplication +from window_builders import WindowFunctionApplication, WindowProjection, WindowSpec from projection_builders import ( BoolLiteralExpr, ColumnExpr, @@ -56,6 +57,7 @@ 
pub def append_node( sort_columns: list[ColumnExpr], aggregate_measures: list[AggregateMeasure], generator_applications: list[GeneratorApplication], + window_projections: list[WindowProjection], projection_assignments: list[ProjectionAssignment], ) -> int: """ @@ -76,6 +78,7 @@ pub def append_node( sort_columns=sort_columns, aggregate_measures=aggregate_measures, generator_applications=generator_applications, + window_projections=window_projections, projection_assignments=projection_assignments, ) prism_stored_nodes.append(PrismStoredNode(store_id_raw=store_id.0, node=appended)) @@ -123,11 +126,13 @@ pub def adopt_cursor_subgraph( adopted_sort_columns = [column for column in source_node.sort_columns] adopted_measures = [measure for measure in source_node.aggregate_measures] adopted_generators = [generator for generator in source_node.generator_applications] + adopted_windows = [projection for projection in source_node.window_projections] adopted_assignments = [assignment for assignment in source_node.projection_assignments] target_group_columns = [column for column in source_node.group_columns] target_sort_columns = [column for column in source_node.sort_columns] target_measures = [measure for measure in source_node.aggregate_measures] target_generators = [generator for generator in source_node.generator_applications] + target_windows = [projection for projection in source_node.window_projections] target_assignments = [assignment for assignment in source_node.projection_assignments] adopted_id = append_node( store_id=target_store_id, @@ -141,6 +146,7 @@ pub def adopt_cursor_subgraph( sort_columns=adopted_sort_columns, aggregate_measures=adopted_measures, generator_applications=adopted_generators, + window_projections=adopted_windows, projection_assignments=adopted_assignments, ) target_store_nodes.append( @@ -156,6 +162,7 @@ pub def adopt_cursor_subgraph( sort_columns=target_sort_columns, aggregate_measures=target_measures, generator_applications=target_generators, 
+ window_projections=target_windows, projection_assignments=target_assignments, ), ) @@ -244,6 +251,8 @@ def _nodes_structurally_equal(candidate: PrismNode, source_node: PrismNode, rema source_node.generator_applications, ): return false + if not _window_projection_lists_structurally_equal(candidate.window_projections, source_node.window_projections): + return false if not _projection_assignments_structurally_equal( candidate.projection_assignments, source_node.projection_assignments, @@ -315,6 +324,45 @@ def _generator_applications_structurally_equal(left: GeneratorApplication, right return _column_exprs_structurally_equal(left.expr, right.expr) +def _window_projection_lists_structurally_equal(left: list[WindowProjection], right: list[WindowProjection]) -> bool: + """Return whether two window projection lists carry identical output names and applications.""" + if len(left) != len(right): + return false + for idx in range(len(left)): + if not _window_projections_structurally_equal(left[idx], right[idx]): + return false + return true + + +def _window_projections_structurally_equal(left: WindowProjection, right: WindowProjection) -> bool: + """Return whether two window projections carry identical names and window semantics.""" + if left.output_name != right.output_name: + return false + return _window_applications_structurally_equal(left.application, right.application) + + +def _window_applications_structurally_equal(left: WindowFunctionApplication, right: WindowFunctionApplication) -> bool: + """Return whether two window applications carry identical registry identity and window specs.""" + if left.kind != right.kind: + return false + if left.function_ref != right.function_ref: + return false + if left.canonical_name != right.canonical_name: + return false + if left.requires_ordering != right.requires_ordering: + return false + if not _column_expr_lists_structurally_equal(left.arguments, right.arguments): + return false + return 
_window_specs_structurally_equal(left.spec, right.spec) + + +def _window_specs_structurally_equal(left: WindowSpec, right: WindowSpec) -> bool: + """Return whether two window specs carry identical partitioning and ordering.""" + if not _column_expr_lists_structurally_equal(left.partition_columns, right.partition_columns): + return false + return _column_expr_lists_structurally_equal(left.sort_columns, right.sort_columns) + + def _text_lists_structurally_equal(left: list[str], right: list[str]) -> bool: """Return whether two string lists are structurally equivalent.""" if len(left) != len(right): diff --git a/src/prism/types.incn b/src/prism/types.incn index 59472c1..666b266 100644 --- a/src/prism/types.incn +++ b/src/prism/types.incn @@ -3,6 +3,7 @@ from aggregate_builders import AggregateMeasure from generator_builders import GeneratorApplication from projection_builders import ColumnExpr, ProjectionAssignment +from window_builders import WindowProjection pub type PrismStoreId = newtype int @@ -19,6 +20,7 @@ pub enum PrismNodeKind(str): GroupBy = "GroupBy" Aggregate = "Aggregate" Generate = "Generate" + Window = "Window" OrderBy = "OrderBy" Limit = "Limit" Explode = "Explode" @@ -44,6 +46,7 @@ pub model PrismNode: pub sort_columns: list[ColumnExpr] pub aggregate_measures: list[AggregateMeasure] pub generator_applications: list[GeneratorApplication] + pub window_projections: list[WindowProjection] pub projection_assignments: list[ProjectionAssignment] diff --git a/src/substrait/extensions.incn b/src/substrait/extensions.incn index 4f0edad..bc2b60b 100644 --- a/src/substrait/extensions.incn +++ b/src/substrait/extensions.incn @@ -58,6 +58,11 @@ pub def scalar_function_name_from_anchor(anchor: u32) -> str: return _function_name_from_anchor_or_raise(anchor, ExtensionFunctionKind.Scalar) +pub def window_function_name_from_anchor(anchor: u32) -> str: + """Resolve one known window-function anchor back to its registered function name.""" + return 
_function_name_from_anchor_or_raise(anchor, ExtensionFunctionKind.Window) + + def _function_extension_specs() -> list[FunctionExtensionSpec]: """Return Substrait extension specs derived from declaration-side registry metadata.""" mut specs: list[FunctionExtensionSpec] = [] @@ -81,6 +86,14 @@ def _function_extension_specs() -> list[FunctionExtensionSpec]: kind=ExtensionFunctionKind.Scalar, ), ) + elif entry.function_class == FunctionClass.Window: + specs.append( + FunctionExtensionSpec( + anchor=entry.substrait.anchor, + name=entry.substrait.function_name, + kind=ExtensionFunctionKind.Window, + ), + ) return specs @@ -145,6 +158,15 @@ def _scalar_extension_decl(anchor: u32) -> SimpleExtensionDeclaration: return raise_value_error(message) +def _window_extension_decl(anchor: u32) -> SimpleExtensionDeclaration: + """Build one window-function extension declaration for the provided anchor.""" + match _function_spec_from_anchor_of_kind(anchor, ExtensionFunctionKind.Window): + Ok(spec) => return _function_extension_decl(spec) + Err(err) => + message = err.error_message() + return raise_value_error(message) + + def _expr_uses_scalar_function_anchor(expr: Expression, expected_anchor: u32) -> bool: """Return whether one expression tree uses the requested scalar-function anchor.""" match expr.rex_type: @@ -264,6 +286,19 @@ def _rel_uses_aggregate_function_anchor(rel: Rel, expected_anchor: u32) -> bool: return false +def _rel_uses_window_function_anchor(rel: Rel, expected_anchor: u32) -> bool: + """Return whether one relation subtree uses the requested window-function anchor.""" + if let Some(RelType.Window(window_rel)) = rel.rel_type.clone(): + for window_function in window_rel.window_functions: + if window_function.function_reference == expected_anchor: + return true + + for child in relation_children(rel): + if _rel_uses_window_function_anchor(child, expected_anchor): + return true + return false + + def _rel_uses_scalar_function_anchor(rel: Rel, expected_anchor: u32) -> 
bool: """Return whether one relation subtree uses the requested scalar-function anchor.""" match rel.rel_type.clone(): @@ -336,6 +371,17 @@ def _aggregate_extension_anchors_for_rel(rel: Rel) -> list[u32]: return anchors +def _window_extension_anchors_for_rel(rel: Rel) -> list[u32]: + """Collect window-function anchors used by one relation subtree in stable declaration order.""" + mut anchors: list[u32] = [] + for spec in _function_extension_specs(): + if (spec.kind == ExtensionFunctionKind.Window and _rel_uses_window_function_anchor(rel.clone(), spec.anchor) and not anchors.contains( + spec.anchor, + )): + anchors.append(spec.anchor) + return anchors + + def _scalar_extension_anchors_for_rel(rel: Rel) -> list[u32]: """Collect scalar-function anchors used by one relation subtree in stable declaration order.""" if _rel_uses_if_then(rel.clone()): @@ -362,8 +408,13 @@ def _scalar_extension_decls(anchors: list[u32]) -> list[SimpleExtensionDeclarati return [_scalar_extension_decl(anchor) for anchor in anchors] +def _window_extension_decls(anchors: list[u32]) -> list[SimpleExtensionDeclaration]: + """Lower window-function anchors into extension declarations in the provided order.""" + return [_window_extension_decl(anchor) for anchor in anchors] + + def _extension_decl_for_anchor(anchor: u32) -> SimpleExtensionDeclaration: - """Lower one known aggregate or scalar function anchor into its extension declaration.""" + """Lower one known aggregate, window, or scalar function anchor into its extension declaration.""" match _function_spec_from_anchor(anchor): Ok(spec) => return _function_extension_decl(spec) Err(err) => @@ -372,8 +423,9 @@ def _extension_decl_for_anchor(anchor: u32) -> SimpleExtensionDeclaration: def _plan_extension_anchors_for_rel(rel: Rel) -> list[u32]: - """Collect aggregate and scalar function anchors used by one relation subtree in stable plan order.""" + """Collect function anchors used by one relation subtree in stable plan order.""" mut anchors = 
_aggregate_extension_anchors_for_rel(rel.clone()) + anchors.extend(_window_extension_anchors_for_rel(rel.clone())) anchors.extend(_scalar_extension_anchors_for_rel(rel)) return anchors diff --git a/src/substrait/function_extensions.incn b/src/substrait/function_extensions.incn index 649a680..72e5d5f 100644 --- a/src/substrait/function_extensions.incn +++ b/src/substrait/function_extensions.incn @@ -12,6 +12,7 @@ pub enum ExtensionFunctionKind(str): Aggregate = "aggregate" Scalar = "scalar" + Window = "window" @derive(Clone) @@ -75,6 +76,9 @@ pub const MAP_EXTRACT_FUNCTION_ANCHOR: u32 = 48 pub const NAMED_STRUCT_FUNCTION_ANCHOR: u32 = 49 pub const ARRAY_HAS_ANY_FUNCTION_ANCHOR: u32 = 50 pub const ARRAY_FLATTEN_FUNCTION_ANCHOR: u32 = 51 +pub const ROW_NUMBER_FUNCTION_ANCHOR: u32 = 52 +pub const RANK_FUNCTION_ANCHOR: u32 = 53 +pub const DENSE_RANK_FUNCTION_ANCHOR: u32 = 54 const FUNCTION_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/functions.yaml" const EXPLODE_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/unnest.yaml#explode" const EXPLODE_OUTER_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/unnest.yaml#explode_outer" diff --git a/src/substrait/inspect.incn b/src/substrait/inspect.incn index 063061f..7e23b48 100644 --- a/src/substrait/inspect.incn +++ b/src/substrait/inspect.incn @@ -7,7 +7,7 @@ inspection utilities used by tests, dataset carriers, and conformance validation from rust::std::boxed import Box from rust::std::primitive import i32 as RustI32 -from rust::substrait::proto import AggregateRel, Expression, Plan, ReadRel, Rel, RelCommon +from rust::substrait::proto import AggregateRel, ConsistentPartitionWindowRel, Expression, Plan, ReadRel, Rel, RelCommon from rust::substrait::proto::aggregate_rel import Measure from rust::substrait::proto::aggregate_function import AggregationInvocation from rust::substrait::proto::function_argument import ArgType @@ -20,7 +20,7 @@ from rust::substrait::proto::sort_field import SortDirection, 
SortKind from aggregate_builders import AggregateKind, AggregateMeasure from projection_builders import scalar_expr_output_name from substrait.expr_lowering import field_index_from_expression, project_output_columns_with_emit, rust_u32_to_int -from substrait.extensions import aggregate_function_name_from_anchor +from substrait.extensions import aggregate_function_name_from_anchor, window_function_name_from_anchor from substrait.function_extensions import ( explode_extension_uri, explode_outer_extension_uri, @@ -158,6 +158,30 @@ def _aggregate_output_columns(aggregate_rel: AggregateRel) -> list[str]: return columns +def _window_input_columns(window_rel: ConsistentPartitionWindowRel) -> list[str]: + """Resolve the current input-column names feeding one window relation.""" + match window_rel.input: + Some(child) => return relation_output_columns(child.as_ref().clone()) + None => return [] + + +def _window_output_columns(window_rel: ConsistentPartitionWindowRel) -> list[str]: + """Return input-column names followed by best-effort lowered window output names.""" + mut columns = _window_input_columns(window_rel.clone()) + for idx, window_function in enumerate(window_rel.window_functions): + function_name = window_function_name_from_anchor(window_function.function_reference) + if len(function_name) > 0: + columns.append(function_name) + else: + columns.append(f"window_{idx}") + return columns + + +def _window_function_names_for_rel(window_rel: ConsistentPartitionWindowRel) -> list[str]: + """Return lowered window-function names in declaration order.""" + return [window_function_name_from_anchor(fun.function_reference) for fun in window_rel.window_functions] + + def _relation_output_columns(rel: Rel) -> list[str]: """Return the current best-effort output column names for one relation subtree.""" match rel.rel_type.clone(): @@ -220,6 +244,7 @@ def _relation_output_columns(rel: Rel) -> list[str]: return [] return _relation_output_columns(set_rel.inputs[0]) 
Some(RelType.Aggregate(aggregate_rel)) => return _aggregate_output_columns(aggregate_rel.as_ref().clone()) + Some(RelType.Window(window_rel)) => return _window_output_columns(window_rel.as_ref().clone()) _ => return [] @@ -228,6 +253,27 @@ pub def relation_output_columns(rel: Rel) -> list[str]: return _relation_output_columns(rel) +pub def window_function_names(rel: Rel) -> list[str]: + """Return lowered window-function names when `rel` is a window root.""" + match rel.rel_type: + Some(RelType.Window(window_rel)) => return _window_function_names_for_rel(window_rel.as_ref().clone()) + _ => return [] + + +pub def window_partition_count(rel: Rel) -> int: + """Return the number of partition expressions when `rel` is a window root.""" + match rel.rel_type: + Some(RelType.Window(window_rel)) => return len(window_rel.partition_expressions) + _ => return 0 + + +pub def window_sort_count(rel: Rel) -> int: + """Return the number of sort expressions when `rel` is a window root.""" + match rel.rel_type: + Some(RelType.Window(window_rel)) => return len(window_rel.sorts) + _ => return 0 + + def _extension_single_output_columns(input_columns: list[str], extension_uri: str) -> list[str]: """Return best-effort output columns for known extension-single relation encodings.""" mut columns: list[str] = [] @@ -332,6 +378,7 @@ pub def relation_kind_name(rel: Rel) -> str: Some(RelType.Join(_)) => return "JoinRel" Some(RelType.Cross(_)) => return "CrossRel" Some(RelType.Aggregate(_)) => return "AggregateRel" + Some(RelType.Window(_)) => return "ConsistentPartitionWindowRel" Some(RelType.Sort(_)) => return "SortRel" Some(RelType.Fetch(_)) => return "FetchRel" Some(RelType.Set(_)) => return "SetRel" diff --git a/src/substrait/relations.incn b/src/substrait/relations.incn index b075e5f..9ae957c 100644 --- a/src/substrait/relations.incn +++ b/src/substrait/relations.incn @@ -13,6 +13,7 @@ from rust::std::primitive import i32 as RustI32 from rust::substrait::proto import ( AggregateFunction, 
AggregateRel, + ConsistentPartitionWindowRel, CrossRel, ExtensionSingleRel, Expression, @@ -32,6 +33,7 @@ from rust::substrait::proto import ( ) from rust::substrait::proto::aggregate_function import AggregationInvocation from rust::substrait::proto::aggregate_rel import Grouping, Measure +from rust::substrait::proto::consistent_partition_window_rel import WindowRelFunction from rust::substrait::proto::expression::nested import Struct as NestedStruct from rust::substrait::proto::fetch_rel import CountMode, OffsetMode from rust::substrait::proto::function_argument import ArgType @@ -48,6 +50,7 @@ from function_registry import FunctionClass, FunctionRegistryEntry, SubstraitMap from functions.registry import function_registry_entry from generator_builders import GeneratorApplication from projection_builders import ColumnExpr, ProjectionAssignment, ScalarFunctionApplicationExpr, col +from window_builders import WindowFunctionApplication, WindowProjection from substrait.expr_lowering import ( bool_expr, filter_predicate_expr, @@ -91,6 +94,18 @@ model ResolvedGeneratorApplication: expr: Expression +@derive(Clone) +model ResolvedWindowProjection: + """One named window projection resolved against input columns and registry metadata.""" + + output_name: str + application: WindowFunctionApplication + entry: FunctionRegistryEntry + arguments: list[FunctionArgument] + partition_expressions: list[Expression] + sorts: list[SortField] + + pub enum SubstraitJoinKind: Inner Left @@ -146,6 +161,11 @@ def _rel_aggregate(aggregate: AggregateRel) -> Rel: return Rel(rel_type=Some(RelType.Aggregate(Box.new(aggregate)))) +def _rel_window(window: ConsistentPartitionWindowRel) -> Rel: + """Wrap ConsistentPartitionWindowRel payload into one generic Rel union value.""" + return Rel(rel_type=Some(RelType.Window(Box.new(window)))) + + def _rel_sort(sort: SortRel) -> Rel: """Wrap SortRel payload into one generic Rel union value.""" return Rel(rel_type=Some(RelType.Sort(Box.new(sort)))) @@ -316,6 
+336,64 @@ def _validate_generator_output_columns( return Ok(None) +def _window_registry_entry( + application: WindowFunctionApplication, +) -> Result[FunctionRegistryEntry, SubstraitLoweringError]: + """Resolve one window registry entry and validate its semantic class.""" + match function_registry_entry(application.function_ref): + Some(entry) => + if entry.function_class != FunctionClass.Window: + return Err(invalid_scalar_expression(f"{entry.function_ref} is not registered as a window function")) + if entry.substrait.kind != SubstraitMappingKind.ExtensionFunction: + return Err( + invalid_scalar_expression(f"{entry.function_ref} does not declare a window extension mapping"), + ) + return Ok(entry) + None => + return Err(invalid_scalar_expression(f"missing window registry entry for `{application.canonical_name}`")) + + +def _resolved_window_projection( + input_columns: list[str], + projection: WindowProjection, +) -> Result[ResolvedWindowProjection, SubstraitLoweringError]: + """Resolve one window projection against input-column names.""" + if len(projection.output_name) == 0: + return Err(invalid_scalar_expression("window output alias must be non-empty")) + application = projection.application + if application.requires_ordering and len(application.spec.sort_columns) == 0: + return Err(invalid_scalar_expression(f"{application.function_ref} requires an explicit window ordering")) + return Ok( + ResolvedWindowProjection( + output_name=projection.output_name, + application=application.clone(), + entry=_window_registry_entry(application.clone())?, + arguments=[FunctionArgument(arg_type=Some(ArgType.Value(scalar_expr(input_columns, arg)?))) for arg in application.arguments], + partition_expressions=[scalar_expr(input_columns, column)? for column in application.spec.partition_columns], + sorts=[_sort_field(input_columns, column)? 
for column in application.spec.sort_columns], + ), + ) + + +def _resolved_window_projection_to_substrait( + projection: ResolvedWindowProjection, +) -> Result[WindowRelFunction, SubstraitLoweringError]: + """Lower one resolved window projection into a Substrait window function payload.""" + return Ok( + WindowRelFunction( + function_reference=projection.entry.substrait.anchor, + arguments=projection.arguments, + options=[], + output_type=None, + phase=AggregationPhase.InitialToResult.into(), + invocation=AggregationInvocation.All.into(), + lower_bound=None, + upper_bound=None, + bounds_type=0, + ), + ) + + def _contains_text(values: list[str], expected: str) -> bool: """Return whether a string list contains a value.""" for value in values: @@ -693,6 +771,46 @@ pub def try_generator_rel_of_columns( return Ok(extension_single_rel(input, resolved.entry.substrait.uri)) +pub def window_rel(input: Rel, projection: WindowProjection) -> Rel: + """Wrap a child relation in a window relation with one named window projection.""" + return _lowered_rel_or_raise(try_window_rel(input, projection)) + + +pub def try_window_rel(input: Rel, projection: WindowProjection) -> Result[Rel, SubstraitLoweringError]: + """Fallibly wrap a child relation in a window relation with one named window projection.""" + return try_window_rel_of_columns(input.clone(), relation_output_columns(input), [projection]) + + +pub def window_rel_of_columns(input: Rel, input_columns: list[str], projections: list[WindowProjection]) -> Rel: + """Wrap a child relation in a window relation using explicit input-column names.""" + return _lowered_rel_or_raise(try_window_rel_of_columns(input, input_columns, projections)) + + +pub def try_window_rel_of_columns( + input: Rel, + input_columns: list[str], + projections: list[WindowProjection], +) -> Result[Rel, SubstraitLoweringError]: + """Fallibly wrap a child relation in a window relation using explicit input-column names.""" + if len(projections) == 0: + return 
Err(invalid_scalar_expression("window relation requires at least one window projection")) + if len(projections) > 1: + return Err(invalid_scalar_expression("window relation currently accepts exactly one window projection")) + resolved = _resolved_window_projection(input_columns, projections[0])? + return Ok( + _rel_window( + ConsistentPartitionWindowRel( + common=Some(_direct_common()), + input=Some(Box.new(input)), + window_functions=[_resolved_window_projection_to_substrait(resolved.clone())?], + partition_expressions=resolved.partition_expressions, + sorts=resolved.sorts, + advanced_extension=None, + ), + ), + ) + + pub def sort_rel(input: Rel) -> Rel: """Wrap a child relation in `SortRel` using the first known output column as the default sort key.""" input_columns = relation_output_columns(input.clone()) diff --git a/src/substrait/traversal.incn b/src/substrait/traversal.incn index 93ea530..f7bbcec 100644 --- a/src/substrait/traversal.incn +++ b/src/substrait/traversal.incn @@ -37,6 +37,10 @@ pub def relation_children(rel: Rel) -> list[Rel]: match aggregate.input: Some(child) => return [child.as_ref().clone()] None => return [] + Some(RelType.Window(window)) => + match window.input: + Some(child) => return [child.as_ref().clone()] + None => return [] Some(RelType.Sort(sort)) => match sort.input: Some(child) => return [child.as_ref().clone()] diff --git a/src/window_builders.incn b/src/window_builders.incn new file mode 100644 index 0000000..5e9f8ea --- /dev/null +++ b/src/window_builders.incn @@ -0,0 +1,152 @@ +""" +Window specification and window-function builder surface. + +Window applications are relation-aware expressions: they produce one value per input row while reading a partition of +related rows. They intentionally do not reuse scalar expression nodes, so invalid scalar positions can remain +diagnosable as the query surface grows. 
+""" + +from rust::incan_stdlib::errors import raise_value_error +from function_registry import function_ref_for +from projection_builders import ColumnExpr + + +@derive(Clone) +pub enum WindowFunctionKind(str): + """Supported window function kinds in the current foundation slice.""" + + RowNumber = "row_number" + Rank = "rank" + DenseRank = "dense_rank" + + +@derive(Clone) +pub model WindowSpec: + """Partitioning and ordering shared by one or more window function applications.""" + + pub partition_columns: list[ColumnExpr] + pub sort_columns: list[ColumnExpr] + + def partition_by(self, columns: list[ColumnExpr]) -> Self: + """Return this window spec with partition expressions replaced.""" + return WindowSpec(partition_columns=columns, sort_columns=self.sort_columns) + + def order_by(self, columns: list[ColumnExpr]) -> Self: + """Return this window spec with ordering expressions replaced.""" + return WindowSpec(partition_columns=self.partition_columns, sort_columns=columns) + + +@derive(Clone) +pub model WindowFunctionApplication: + """One placed window function application.""" + + pub kind: WindowFunctionKind + pub function_ref: str + pub canonical_name: str + pub arguments: list[ColumnExpr] + pub spec: WindowSpec + pub requires_ordering: bool + + +@derive(Clone) +pub model WindowFunctionCall: + """Unplaced window function call waiting for an explicit window specification.""" + + pub kind: WindowFunctionKind + pub function_ref: str + pub canonical_name: str + pub arguments: list[ColumnExpr] + pub requires_ordering: bool + + def over(self, spec: WindowSpec) -> WindowFunctionApplication: + """Place this window function call over a concrete window specification.""" + return WindowFunctionApplication( + kind=self.kind, + function_ref=self.function_ref, + canonical_name=self.canonical_name, + arguments=self.arguments, + spec=spec, + requires_ordering=self.requires_ordering, + ) + + +@derive(Clone) +pub model WindowProjection: + """One named output column backed by a 
placed window function application.""" + + pub output_name: str + pub application: WindowFunctionApplication + + +pub def window() -> WindowSpec: + """Build an empty window specification.""" + return WindowSpec(partition_columns=[], sort_columns=[]) + + +pub def row_number() -> WindowFunctionCall: + """Build a `row_number` window function call.""" + return _ranking_call("row_number", WindowFunctionKind.RowNumber) + + +pub def rank() -> WindowFunctionCall: + """Build a `rank` window function call.""" + return _ranking_call("rank", WindowFunctionKind.Rank) + + +pub def dense_rank() -> WindowFunctionCall: + """Build a `dense_rank` window function call.""" + return _ranking_call("dense_rank", WindowFunctionKind.DenseRank) + + +pub def window_projection(output_name: str, application: WindowFunctionApplication) -> WindowProjection: + """Build one named window projection after validating its output alias.""" + if len(output_name) == 0: + return raise_value_error("window output alias must be non-empty") + return WindowProjection(output_name=output_name, application=application) + + +pub def window_output_columns(input_columns: list[str], projections: list[WindowProjection]) -> list[str]: + """Return output columns after applying window projections with add-or-replace semantics.""" + mut output_columns: list[str] = [] + output_columns.extend(input_columns) + for projection in projections: + existing_idx = _index_of_text(output_columns, projection.output_name) + if existing_idx >= 0: + output_columns[existing_idx] = projection.output_name + else: + output_columns.append(projection.output_name) + return output_columns + + +def _ranking_call(canonical_name: str, kind: WindowFunctionKind) -> WindowFunctionCall: + """Build one ranking-family window call.""" + return WindowFunctionCall( + kind=kind, + function_ref=function_ref_for(canonical_name), + canonical_name=canonical_name, + arguments=[], + requires_ordering=true, + ) + + +def _index_of_text(values: list[str], expected: 
str) -> int: + """Return the index of a string value, or -1 when absent.""" + for idx, value in enumerate(values): + if value == expected: + return idx + return -1 + + +module tests: + from projection_builders import col, column_expr_name + def test_window_spec_builders_preserve_partition_and_order() -> None: + spec = window().partition_by([col("customer_id")]).order_by([col("amount")]) + assert len(spec.partition_columns) == 1 + assert column_expr_name(spec.partition_columns[0]) == "customer_id" + assert len(spec.sort_columns) == 1 + assert column_expr_name(spec.sort_columns[0]) == "amount" + def test_ranking_call_over_records_registry_identity() -> None: + application = rank().over(window().order_by([col("amount")])) + assert application.kind == WindowFunctionKind.Rank + assert application.function_ref == "inql.functions.rank" + assert application.requires_ordering diff --git a/tests/test_dataset.incn b/tests/test_dataset.incn index 3140a03..b933297 100644 --- a/tests/test_dataset.incn +++ b/tests/test_dataset.incn @@ -28,14 +28,17 @@ from functions import ( mul, posexplode, posexplode_outer, + row_number, str_expr, str_lit, sum, + window, ) from projection_builders import ColumnExprKind, column_expr_kind, column_expr_name from substrait.function_extensions import ( explode_extension_uri, explode_outer_extension_uri, + function_extension_uri, posexplode_extension_uri, posexplode_outer_extension_uri, ) @@ -459,6 +462,10 @@ def test_lazy_frame__native_prism_ops_preserve_current_boundary_shapes() -> None generated_positional_outer: LazyFrame[Order] = lazy_frame_named_table("orders_generator_dataset").generate( posexplode_outer(col("line_items"), "position", "line_item"), ) + windowed: LazyFrame[Order] = lazy_frame_named_table("orders").with_window_column( + "row_num", + row_number().over(window().order_by([col("id")])), + ) # -- Assert -- assert relation_kind_name(root_rel(projected.to_substrait_plan())) == "ProjectRel", "select should lower through the project 
boundary shape" @@ -472,6 +479,9 @@ def test_lazy_frame__native_prism_ops_preserve_current_boundary_shapes() -> None assert plan_has_extension_urn(generated_outer.to_substrait_plan(), explode_outer_extension_uri()), "outer explode should use its relation extension URI" assert plan_has_extension_urn(generated_positional.to_substrait_plan(), posexplode_extension_uri()), "posexplode should use its relation extension URI" assert plan_has_extension_urn(generated_positional_outer.to_substrait_plan(), posexplode_outer_extension_uri()), "posexplode_outer should use its relation extension URI" + assert relation_kind_name(root_rel(windowed.to_substrait_plan())) == "ConsistentPartitionWindowRel", "with_window_column should lower through the window boundary shape" + assert windowed.planned_columns() == ["id", "row_num"], "window projections should append declared output aliases" + assert plan_has_extension_urn(windowed.to_substrait_plan(), function_extension_uri()), "window plans should use the shared function extension URI" def test_lazy_frame__deeper_independent_roots_still_lower_with_stable_shapes() -> None: diff --git a/tests/test_function_registry.incn b/tests/test_function_registry.incn index 424147e..23902fa 100644 --- a/tests/test_function_registry.incn +++ b/tests/test_function_registry.incn @@ -85,13 +85,17 @@ from functions import ( or_, posexplode, posexplode_outer, + dense_rank, registered_substrait_mapped_function_refs, + rank, round, + row_number, str_expr, str_lit, sub, sum, try_cast, + window, ) from function_registry import ( FunctionAliasPolicy, @@ -138,6 +142,7 @@ from substrait.function_extensions import ( CEIL_FUNCTION_ANCHOR, COALESCE_FUNCTION_ANCHOR, COUNT_FUNCTION_ANCHOR, + DENSE_RANK_FUNCTION_ANCHOR, DIVIDE_FUNCTION_ANCHOR, EQUAL_FUNCTION_ANCHOR, FLOOR_FUNCTION_ANCHOR, @@ -165,6 +170,8 @@ from substrait.function_extensions import ( NOT_FUNCTION_ANCHOR, NULLIF_FUNCTION_ANCHOR, OR_FUNCTION_ANCHOR, + RANK_FUNCTION_ANCHOR, + ROW_NUMBER_FUNCTION_ANCHOR, 
ROUND_FUNCTION_ANCHOR, SUBTRACT_FUNCTION_ANCHOR, SUM_FUNCTION_ANCHOR, @@ -231,12 +238,12 @@ def _local_entry_by_namespace_and_name_or_fail( def _expected_registry_names() -> list[str]: """Return the expected registered public helper names.""" - return ["col", "lit", "sum", "count", "count_expr", "count_distinct", "count_if", "avg", "min", "max", "int_expr", "float_expr", "str_expr", "bool_expr", "add", "mul", "int_lit", "str_lit", "bool_lit", "always_true", "always_false", "eq", "gt", "cast", "try_cast", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", "not_", "is_null", "is_not_null", "is_nan", "is_not_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "case_when", "in_", "between", "asc", "desc", "asc_nulls_first", "asc_nulls_last", "desc_nulls_first", "desc_nulls_last", "abs", "ceil", "floor", "round", "array", "array_contains", "array_distinct", "array_except", "array_flatten", "array_intersect", "array_join", "array_position", "array_reverse", "array_slice", "array_sort", "array_union", "arrays_overlap", "cardinality", "element_at", "map_contains_key", "map_entries", "map_extract", "map_from_arrays", "map_keys", "map_values", "named_struct", "explode", "explode_outer", "posexplode", "posexplode_outer"] + return ["col", "lit", "sum", "count", "count_expr", "count_distinct", "count_if", "avg", "min", "max", "int_expr", "float_expr", "str_expr", "bool_expr", "add", "mul", "int_lit", "str_lit", "bool_lit", "always_true", "always_false", "eq", "gt", "cast", "try_cast", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", "not_", "is_null", "is_not_null", "is_nan", "is_not_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "case_when", "in_", "between", "asc", "desc", "asc_nulls_first", "asc_nulls_last", "desc_nulls_first", "desc_nulls_last", "abs", "ceil", "floor", "round", "array", "array_contains", "array_distinct", "array_except", "array_flatten", "array_intersect", "array_join", "array_position", "array_reverse", "array_slice", "array_sort", 
"array_union", "arrays_overlap", "cardinality", "element_at", "map_contains_key", "map_entries", "map_extract", "map_from_arrays", "map_keys", "map_values", "named_struct", "explode", "explode_outer", "posexplode", "posexplode_outer", "window", "row_number", "rank", "dense_rank"] def _expected_substrait_mapped_names() -> list[str]: """Return helpers with concrete Substrait extension-function mappings.""" - return ["sum", "count", "count_expr", "avg", "min", "max", "add", "mul", "eq", "gt", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", "not_", "is_null", "is_not_null", "is_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "between", "abs", "ceil", "floor", "round", "array", "array_contains", "array_distinct", "array_except", "array_flatten", "array_intersect", "array_join", "array_position", "array_reverse", "array_slice", "array_sort", "array_union", "arrays_overlap", "cardinality", "element_at", "map_entries", "map_extract", "map_from_arrays", "map_keys", "map_values", "named_struct"] + return ["sum", "count", "count_expr", "avg", "min", "max", "add", "mul", "eq", "gt", "ne", "lt", "lte", "gte", "equal_null", "and_", "or_", "not_", "is_null", "is_not_null", "is_nan", "sub", "div", "mod", "neg", "coalesce", "nullif", "between", "abs", "ceil", "floor", "round", "array", "array_contains", "array_distinct", "array_except", "array_flatten", "array_intersect", "array_join", "array_position", "array_reverse", "array_slice", "array_sort", "array_union", "arrays_overlap", "cardinality", "element_at", "map_entries", "map_extract", "map_from_arrays", "map_keys", "map_values", "named_struct", "row_number", "rank", "dense_rank"] def _exercise_current_public_helpers() -> None: @@ -325,6 +332,10 @@ def _exercise_current_public_helpers() -> None: explode_outer(tags, "tag") posexplode(tags, "position", "tag") posexplode_outer(tags, "position", "tag") + window() + row_number() + rank() + dense_rank() return @@ -641,6 +652,23 @@ def 
test_function_registry__generator_helpers_are_relation_extensions() -> None: _assert_relation_extension_mapping("posexplode_outer", "posexplode_outer", posexplode_outer_extension_uri()) +def test_function_registry__window_helpers_are_relation_window_functions() -> None: + """Assert window helpers carry relation-aware window-function metadata.""" + # -- Arrange -- + _exercise_current_public_helpers() + window_entry = _entry_by_name_or_fail("window") + row_number_entry = _entry_by_name_or_fail("row_number") + + # -- Act / Assert -- + assert window_entry.function_class == FunctionClass.Window, "window spec builder should be classified as window metadata" + assert window_entry.substrait.kind == SubstraitMappingKind.StructuralFunction, "window spec builder should be structural metadata" + assert window_entry.substrait.function_name == "window_spec", "window spec builder should name the window-spec context" + assert row_number_entry.function_class == FunctionClass.Window, "ranking helpers should be classified as window functions" + _assert_extension_mapping("row_number", "row_number", ROW_NUMBER_FUNCTION_ANCHOR) + _assert_extension_mapping("rank", "rank", RANK_FUNCTION_ANCHOR) + _assert_extension_mapping("dense_rank", "dense_rank", DENSE_RANK_FUNCTION_ANCHOR) + + def test_function_registry__ordering_helpers_are_contextual_sort_fields() -> None: """Assert RFC 015 ordering helpers are modeled as sort-field context helpers.""" # -- Arrange -- diff --git a/tests/test_prism.incn b/tests/test_prism.incn index fc1f097..16b8564 100644 --- a/tests/test_prism.incn +++ b/tests/test_prism.incn @@ -1,12 +1,13 @@ """Internal Prism engine tests against the shared-store cursor substrate.""" -from functions import always_false, always_true, col, count, count_expr, explode, lit, mul, sum +from functions import always_false, always_true, col, count, count_expr, explode, lit, mul, row_number, sum, window from prism import ( PrismCursor, prism_cursor_apply_filter, 
prism_cursor_apply_generate, prism_cursor_apply_limit, prism_cursor_apply_select, + prism_cursor_apply_with_window_column, prism_cursor_authored_node_count, prism_cursor_named_table, prism_cursor_rewrite_applied_rule_count, @@ -21,7 +22,8 @@ from prism import ( prism_cursor_tip_origin_id, prism_cursors_share_store, ) -from substrait.inspect import plan_contains_relation_kind, relation_kind_name, root_rel +from substrait.function_extensions import function_extension_uri +from substrait.inspect import plan_contains_relation_kind, plan_has_extension_urn, relation_kind_name, root_rel from substrait.plans import plan_encoded_len from substrait.schema_registry import register_named_table_schema from substrait.schema import RowColumnSpec, SubstraitPrimitiveKind @@ -230,6 +232,10 @@ def test_prism__cursor_native_nodes_cover_current_method_surface() -> None: generated: PrismCursor[Order] = prism_cursor_named_table(str("orders_generator_prism")).generate( explode(col("line_items"), "line_item"), ) + windowed: PrismCursor[Order] = prism_cursor_named_table(str("orders")).with_window_column( + "row_num", + row_number().over(window().order_by([col("id")])), + ) # -- Assert -- assert prism_cursor_tip_kind_name(projected) == str("Project"), "select should append a native project node" @@ -239,7 +245,24 @@ def test_prism__cursor_native_nodes_cover_current_method_surface() -> None: assert prism_cursor_tip_kind_name(limited) == str("Limit"), "limit should append a native limit node" assert prism_cursor_tip_kind_name(exploded) == str("Explode"), "explode should append a native explode node" assert prism_cursor_tip_kind_name(generated) == str("Generate"), "generate should append a native generator node" + assert prism_cursor_tip_kind_name(windowed) == str("Window"), "with_window_column should append a native window node" assert prism_cursor_output_columns(generated) == ["id", "line_items", "line_item"], "generate should append declared output aliases" + assert 
prism_cursor_output_columns(windowed) == ["id", "row_num"], "window projections should append declared output aliases" + + +def test_prism__window_column_lowers_through_substrait_boundary() -> None: + # -- Arrange -- + _register_projection_test_schema(str("orders")) + base: PrismCursor[Order] = prism_cursor_named_table(str("orders")) + + # -- Act -- + windowed = prism_cursor_apply_with_window_column(base, "row_num", row_number().over(window().order_by([col("id")]))) + plan = windowed.to_substrait_plan() + + # -- Assert -- + assert prism_cursor_tip_kind_name(windowed) == str("Window"), "window helper should create a Prism window node" + assert relation_kind_name(root_rel(plan)) == str("ConsistentPartitionWindowRel"), "window Prism node should lower to a window relation" + assert plan_has_extension_urn(plan, function_extension_uri()), "window plans should declare the shared function extension URI" def test_prism__rewrite_eliminates_filter_true_by_default() -> None: diff --git a/tests/test_substrait_plan.incn b/tests/test_substrait_plan.incn index 04bb902..8a44395 100644 --- a/tests/test_substrait_plan.incn +++ b/tests/test_substrait_plan.incn @@ -32,6 +32,7 @@ from functions import ( count_distinct, count_expr, count_if, + dense_rank, desc, div, eq, @@ -63,12 +64,15 @@ from functions import ( not_, nullif, or_, + rank, round, + row_number, sub, sum, try_cast, cardinality, element_at, + window, ) from projection_builders import ColumnExpr, with_column_assignment from substrait.errors import SubstraitLoweringErrorKind @@ -100,6 +104,9 @@ from substrait.inspect import ( sort_field_count, sort_field_direction_name, sort_field_expr_index, + window_function_names, + window_partition_count, + window_sort_count, ) from substrait.plans import ( plan_encoded_len, @@ -131,9 +138,12 @@ from substrait.relations import ( set_rel_of_kind, sort_rel_of_columns, try_aggregate_rel_of_columns, + try_window_rel_of_columns, + window_rel_of_columns, ) from substrait.schema_registry import 
named_table_columns, register_named_table_schema from substrait.schema import RowColumnSpec, SubstraitPrimitiveKind +from window_builders import WindowProjection, window_projection from substrait.conformance import ( ConformanceCapabilityTags, ConformancePortability, @@ -202,6 +212,32 @@ def _register_orders_schema() -> None: register_named_table_schema("orders", [RowColumnSpec(name="id", kind=SubstraitPrimitiveKind.I64, nullable=false)]) +def _register_window_orders_schema() -> None: + register_named_table_schema( + "orders_window", + [RowColumnSpec(name="customer_id", kind=SubstraitPrimitiveKind.String, nullable=false), RowColumnSpec( + name="amount", + kind=SubstraitPrimitiveKind.I64, + nullable=false, + )], + ) + + +def _row_number_projection() -> WindowProjection: + spec = window().partition_by([col("customer_id")]).order_by([desc(col("amount"))]) + return window_projection("row_num", row_number().over(spec)) + + +def _rank_projection() -> WindowProjection: + spec = window().partition_by([col("customer_id")]).order_by([desc(col("amount"))]) + return window_projection("amount_rank", rank().over(spec)) + + +def _dense_rank_projection() -> WindowProjection: + spec = window().partition_by([col("customer_id")]).order_by([desc(col("amount"))]) + return window_projection("dense_amount_rank", dense_rank().over(spec)) + + def _register_fixture_schema(table_name: str) -> None: register_named_table_schema( table_name, @@ -541,6 +577,61 @@ def test_plan__aggregate_rel_rejects_invalid_modifier_shapes() -> None: assert ordered_err.message.contains("ordered aggregate input"), "ordered aggregate diagnostic should identify the unsupported modifier" +def _assert_window_projection_lowers(projection: WindowProjection, expected_function_name: str) -> None: + """Assert one window projection lowers to a concrete Substrait window relation.""" + _register_window_orders_schema() + base = read_named_table_rel("orders_window") + windowed = window_rel_of_columns(base, ["customer_id", 
"amount"], [projection]) + plan = plan_from_root_relation(windowed, ["customer_id", "amount", projection.output_name]) + + assert relation_kind_name(windowed) == "ConsistentPartitionWindowRel", "window lowering should emit the Substrait window relation" + assert plan_contains_relation_kind(plan, "ConsistentPartitionWindowRel"), "window plans should preserve the window relation shape" + assert plan_has_extension_urn(plan, function_extension_uri()), "window function plans should register the shared function extension URN" + assert window_function_names(windowed) == [expected_function_name], "window relation should carry the registered window function" + assert window_partition_count(windowed) == 1, "window relation should lower explicit partition expressions" + assert window_sort_count(windowed) == 1, "ranking window relation should lower explicit ordering" + + +def test_plan__row_number_window_rel_lowers_to_substrait() -> None: + # -- Arrange / Act / Assert -- + _assert_window_projection_lowers(_row_number_projection(), "row_number") + + +def test_plan__rank_window_rel_lowers_to_substrait() -> None: + # -- Arrange / Act / Assert -- + _assert_window_projection_lowers(_rank_projection(), "rank") + + +def test_plan__dense_rank_window_rel_lowers_to_substrait() -> None: + # -- Arrange / Act / Assert -- + _assert_window_projection_lowers(_dense_rank_projection(), "dense_rank") + + +def test_plan__ranking_window_rel_rejects_missing_ordering() -> None: + # -- Arrange -- + _register_window_orders_schema() + base = read_named_table_rel("orders_window") + unordered = window_projection("row_num", row_number().over(window().partition_by([col("customer_id")]))) + + # -- Act -- + result = try_window_rel_of_columns(base, ["customer_id", "amount"], [unordered]) + + # -- Assert -- + assert_is_err(result, "ranking window helpers should require explicit ordering") + + +def test_plan__window_rel_rejects_multiple_projections_until_partition_grouping_lands() -> None: + # -- Arrange -- + 
_register_window_orders_schema() + base = read_named_table_rel("orders_window") + + # -- Act -- + result = try_window_rel_of_columns(base, ["customer_id", "amount"], [_row_number_projection(), _rank_projection()]) + + # -- Assert -- + assert_is_err(result, "current window relation lowering should reject multiple projections explicitly") + + def test_plan__set_rel_uses_operation_enum() -> None: # -- Arrange -- left = read_named_table_rel("orders_current") diff --git a/tests/test_window_functions.incn b/tests/test_window_functions.incn new file mode 100644 index 0000000..17fb30c --- /dev/null +++ b/tests/test_window_functions.incn @@ -0,0 +1,47 @@ +"""Tests for RFC 019 window specification and ranking helpers.""" + +from functions import col, dense_rank, rank, row_number, window +from projection_builders import column_expr_name +from window_builders import WindowFunctionKind, window_output_columns, window_projection + + +def test_window_builders__spec_preserves_partition_and_order_columns() -> None: + # -- Arrange / Act -- + spec = window().partition_by([col("customer_id")]).order_by([col("amount")]) + + # -- Assert -- + assert len(spec.partition_columns) == 1, "window partition should record explicit partition expressions" + assert column_expr_name(spec.partition_columns[0]) == "customer_id", "partition expression should preserve column refs" + assert len(spec.sort_columns) == 1, "window ordering should record explicit sort expressions" + assert column_expr_name(spec.sort_columns[0]) == "amount", "sort expression should preserve column refs" + + +def test_window_builders__ranking_helpers_return_unplaced_calls() -> None: + # -- Arrange -- + spec = window().order_by([col("amount")]) + + # -- Act -- + row_number_app = row_number().over(spec) + rank_app = rank().over(spec) + dense_rank_app = dense_rank().over(spec) + + # -- Assert -- + assert row_number_app.kind == WindowFunctionKind.RowNumber, "row_number should keep typed window identity" + assert rank_app.kind == 
WindowFunctionKind.Rank, "rank should keep typed window identity" + assert dense_rank_app.kind == WindowFunctionKind.DenseRank, "dense_rank should keep typed window identity" + assert row_number_app.requires_ordering, "ranking helpers should require explicit window ordering" + assert rank_app.function_ref == "inql.functions.rank", "rank should derive stable registry identity" + assert dense_rank_app.canonical_name == "dense_rank", "dense_rank should expose its canonical name" + + +def test_window_builders__output_columns_use_add_or_replace_alias_semantics() -> None: + # -- Arrange -- + spec = window().order_by([col("amount")]) + + # -- Act -- + appended = window_output_columns(["id", "amount"], [window_projection("row_num", row_number().over(spec))]) + replaced = window_output_columns(["id", "rank"], [window_projection("rank", rank().over(spec))]) + + # -- Assert -- + assert appended == ["id", "amount", "row_num"], "new window aliases should append to input columns" + assert replaced == ["id", "rank"], "existing window aliases should replace in place without duplicating names"
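
Reviewer note: the add-or-replace alias behavior that the last test exercises can be modeled outside the codebase. The sketch below is a Python approximation, not the Incan implementation in `src/window_builders.incn`. It assumes projections can be reduced to their output alias strings, so the `output_names` parameter here is hypothetical and stands in for `list[WindowProjection]`.

```python
def window_output_columns(input_columns: list[str], output_names: list[str]) -> list[str]:
    """Model add-or-replace alias semantics for window output columns.

    Assumption: each window projection is represented only by its output
    alias (a plain string), unlike the Incan version, which takes
    WindowProjection values.
    """
    # Start from a copy so the caller's input list is never mutated.
    output_columns = list(input_columns)
    for name in output_names:
        if name in output_columns:
            # Replace in place: the alias keeps its original position,
            # so the column list does not gain a duplicate name.
            output_columns[output_columns.index(name)] = name
        else:
            # New alias: append after the existing input columns.
            output_columns.append(name)
    return output_columns


appended = window_output_columns(["id", "amount"], ["row_num"])
replaced = window_output_columns(["id", "rank"], ["rank"])
```

With these inputs, `appended` mirrors the "new alias appends" case and `replaced` mirrors the "existing alias is replaced in place" case from `test_window_builders__output_columns_use_add_or_replace_alias_semantics`. Because only names are tracked, the in-place assignment leaves the list unchanged; it is kept to show where a replacement would occur if richer column metadata were carried.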