diff --git a/pytensor/link/numba/dispatch/blockwise.py b/pytensor/link/numba/dispatch/blockwise.py index af4aa4aee3..49fc4b43f6 100644 --- a/pytensor/link/numba/dispatch/blockwise.py +++ b/pytensor/link/numba/dispatch/blockwise.py @@ -13,6 +13,7 @@ from pytensor.link.numba.dispatch.vectorize_codegen import ( NO_INDEXED_INPUTS, NO_INDEXED_OUTPUTS, + NO_REDUCE_OUTPUTS, NO_SIZE, _jit_options, _vectorized, @@ -96,6 +97,7 @@ def impl(*inputs_and_core_shapes): NO_SIZE, NO_INDEXED_INPUTS, NO_INDEXED_OUTPUTS, + NO_REDUCE_OUTPUTS, ) return impl diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index d1740e78ae..faa889a970 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -29,6 +29,7 @@ from pytensor.link.numba.dispatch.vectorize_codegen import ( NO_INDEXED_INPUTS, NO_INDEXED_OUTPUTS, + NO_REDUCE_OUTPUTS, NO_SIZE, _jit_options, _vectorized, @@ -140,6 +141,38 @@ def scalar_in_place_fn_Minimum(op, idx, res, arr): ] +# Scalar reduce ops that ``accumulate_into_slice`` (and thus fused reductions / +# indexed inc) supports, paired with the in-place form used on the output slice. +# Augmented assignment is used where Numba supports it directly on 0-d arrays +# (``out`` is the core output slice, which is 0-d for scalar cores); Maximum and +# Minimum lack an augmented operator so they use a ufunc + ellipsis-set instead. +_SLICE_ACCUMULATE = { + Add: "{out} += {inner}", + Mul: "{out} *= {inner}", + AND: "{out} &= {inner}", + OR: "{out} |= {inner}", + XOR: "{out} ^= {inner}", + Maximum: "{out}[...] = np.maximum({out}, {inner})", + Minimum: "{out}[...] = np.minimum({out}, {inner})", +} + + +def accumulate_into_slice(scalar_op, out: str, inner: str) -> list[str]: + """In-place accumulation lines for a (possibly 0-d) output slice ``out``. + + Unlike ``scalar_in_place_fn`` (which indexes ``res[idx]`` and breaks on 0-d + cores), this operates on the whole slice so it is valid for both scalar + (0-d) and array cores. Used by both fused reductions and indexed ``inc``. + """ + try: + template = _SLICE_ACCUMULATE[type(scalar_op)] + except KeyError: + raise NotImplementedError( + f"No fused-reduction accumulation for scalar op {scalar_op}" + ) + return [template.format(out=out, inner=inner)] + + @intrinsic def _address_as_void_pointer(typingctx, src): """Returns a void pointer from a given memory address.""" @@ -695,6 +728,7 @@ def impl(*inputs): NO_SIZE, NO_INDEXED_INPUTS, NO_INDEXED_OUTPUTS, + NO_REDUCE_OUTPUTS, ) return impl @@ -715,6 +749,86 @@ def impl(*inputs): return elemwise, elemwise_key +def _reduce_identity(identity, acc_dtype): + """Coerce a reduction identity to a concrete numpy scalar in ``acc_dtype``. + + Non-finite identities (``±inf`` for max/min) become the dtype's bounds for + integer accumulators, mirroring ``create_multiaxis_reducer``. + """ + acc_dtype = np.dtype(acc_dtype) + if acc_dtype.kind in "ui" and not np.isfinite(identity): + identity = ( + np.iinfo(acc_dtype).max + if np.isposinf(identity) + else np.iinfo(acc_dtype).min + ) + return acc_dtype.type(identity) + + +def _build_reduce_impl_src(nout, post_specs): + """Build the ``@overload`` impl source for a fused op with reduction outputs. + + Calls ``_vectorized`` (which produces a keepdims, size-1-on-reduced-axes + buffer in ``acc_dtype``) then, per reduction output, squeezes the reduced + axes back out and casts to the true output dtype. ``post_specs[i]`` is + ``None`` for a passthrough output, else ``(kept_axes, out_dtype, cast_needed)``. + """ + code: list[str | CODE_TOKEN] = [ + "def fused_elemwise_impl(*outer_inputs):", + CODE_TOKEN.INDENT, + "raw = _vectorized(", + CODE_TOKEN.INDENT, + "core_op_fn,", + "input_bc_patterns_enc,", + "output_bc_patterns_enc,", + "output_dtypes_enc,", + "inplace_pattern_enc,", + "True,", + "(),", + "outer_inputs,", + "core_output_shapes,", + "NO_SIZE,", + "indexed_inputs_enc,", + "indexed_outputs_enc,", + "reduce_outputs_enc,", + CODE_TOKEN.DEDENT, + ")", + ] + + def src_access(i): + return "raw" if nout == 1 else f"raw[{i}]" + + out_syms = [] + for i in range(nout): + sym = f"o{i}" + out_syms.append(sym) + src = src_access(i) + spec = post_specs[i] + if spec is None: + code.append(f"{sym} = {src}") + continue + kept_axes, out_dtype, cast_needed = spec + np_dtype = "bool_" if out_dtype == "bool" else out_dtype + if not kept_axes: + # Full reduction → 0-d array of the single accumulated value. + code.append(f"{sym} = np.array({src}.ravel()[0], dtype=np.{np_dtype})") + else: + shape_expr = create_tuple_string( + tuple(f"{src}.shape[{k}]" for k in kept_axes) + ) + expr = f"{src}.reshape({shape_expr})" + if cast_needed: + expr = f"{expr}.astype(np.{np_dtype})" + code.append(f"{sym} = {expr}") + + if nout == 1: + code.append(f"return {out_syms[0]}") + else: + code.append(f"return {create_tuple_string(tuple(out_syms))}") + code.append(CODE_TOKEN.DEDENT) + return build_source_code(code) + + @register_funcify_and_cache_key(IndexedElemwise) def numba_funcify_IndexedElemwise(op, node, **kwargs): """Generate fused Elemwise Numba code with indexed reads and updates. @@ -739,6 +853,7 @@ def numba_funcify_IndexedElemwise(op, node, **kwargs): n_indices = len(indexed_inputs) nin_elemwise = len(elemwise_node.inputs) nout = len(elemwise_node.outputs) + reduced_outputs = op.reduced_outputs or ((None,) * nout) inc_outputs = frozenset( out_idx @@ -747,14 +862,60 @@ def numba_funcify_IndexedElemwise(op, node, **kwargs): for out_idx in entry[0] if entry[2] == "inc" ) + reduced_idxs = frozenset(i for i, r in enumerate(reduced_outputs) if r is not None) + assert not (inc_outputs & reduced_idxs), ( + "An output cannot be both an indexed write and a reduction" + ) + + # accum_fns bake per-output in-place accumulation into store_core_outputs: + # indexed `inc` writes accumulate with +=; reductions use their scalar op + # (Add/Mul/Maximum/...). Both go through accumulate_into_slice so they are + # valid on 0-d (scalar core) slices. Disjoint output indices. + accum_fns: dict = {} + for out_idx in inc_outputs: + accum_fns[out_idx] = lambda out, inner: [f"{out} += {inner}"] + for i in reduced_idxs: + accum_fns[i] = ( + lambda out, inner, _op=reduced_outputs[i][0]: accumulate_into_slice( + _op, out, inner + ) + ) core_op_fn = store_core_outputs( - scalar_op_fn, nin=nin_elemwise, nout=nout, inc_outputs=inc_outputs + scalar_op_fn, nin=nin_elemwise, nout=nout, accum_fns=accum_fns ) input_bc_patterns = tuple(inp.type.broadcastable for inp in elemwise_node.inputs) - output_bc_patterns = tuple(out.type.broadcastable for out in elemwise_node.outputs) - output_dtypes = tuple(out.type.dtype for out in node.outputs) + + # Reduction outputs keep their reduced axes as bc=True (size-1, keepdims) and + # are accumulated in acc_dtype; post_specs squeezes + casts them back to the + # true CAReduce output below. + output_bc_patterns_list = [] + output_dtypes_list = [] + reduce_identities = [] + post_specs: list = [] + has_reductions = bool(reduced_idxs) + for i, inner_out in enumerate(elemwise_node.outputs): + spec = reduced_outputs[i] + if spec is None: + output_bc_patterns_list.append(inner_out.type.broadcastable) + output_dtypes_list.append(node.outputs[i].type.dtype) + post_specs.append(None) + continue + _reduce_op, axes, identity, acc_dtype = spec + bc = list(inner_out.type.broadcastable) + for ax in axes: + bc[ax] = True + output_bc_patterns_list.append(tuple(bc)) + output_dtypes_list.append(str(np.dtype(acc_dtype))) + reduce_identities.append((i, _reduce_identity(identity, acc_dtype))) + kept_axes = tuple(d for d in range(len(bc)) if d not in axes) + out_dtype = node.outputs[i].type.dtype + cast_needed = np.dtype(acc_dtype) != np.dtype(out_dtype) + post_specs.append((kept_axes, out_dtype, cast_needed)) + + output_bc_patterns = tuple(output_bc_patterns_list) + output_dtypes = tuple(output_dtypes_list) inplace_pattern = tuple(elemwise_node.op.inplace_pattern.items()) core_output_shapes = tuple(() for _ in range(nout)) @@ -768,36 +929,71 @@ def numba_funcify_IndexedElemwise(op, node, **kwargs): inplace_pattern_enc = encode_literals(inplace_pattern) indexed_inputs_enc = encode_literals((indexed_inputs, idx_broadcastable)) indexed_outputs_enc = encode_literals(indexed_outputs) + reduce_outputs_enc = encode_literals(tuple(reduce_identities)) def indexed_elemwise_fn(*outer_inputs): - raise NotImplementedError( - "IndexedElemwise cannot be evaluated in Python (non-JIT) mode." - ) + # Python-mode fallback (e.g. Numba's ``eval_python_only`` path, which + # runs the funcified function without JIT). Evaluate the inner fgraph + # faithfully, exactly like ``OpFromGraph.perform`` — so the fused op + # still produces valid results when executed directly in Python. The + # JIT path never reaches this body: the ``@overload`` below supplies the + # vectorized loop implementation. + res = op.fn(*outer_inputs) + return res[0] if len(res) == 1 else tuple(res) + + if not has_reductions: + + @overload(indexed_elemwise_fn, jit_options=_jit_options) + def ov_indexed_elemwise_fn(*outer_inputs): + def impl(*outer_inputs): + return _vectorized( + core_op_fn, + input_bc_patterns_enc, + output_bc_patterns_enc, + output_dtypes_enc, + inplace_pattern_enc, + True, # allow_core_scalar + (), # constant_inputs + outer_inputs, + core_output_shapes, + NO_SIZE, + indexed_inputs_enc, + indexed_outputs_enc, + reduce_outputs_enc, + ) - @overload(indexed_elemwise_fn, jit_options=_jit_options) - def ov_indexed_elemwise_fn(*outer_inputs): - def impl(*outer_inputs): - return _vectorized( - core_op_fn, - input_bc_patterns_enc, - output_bc_patterns_enc, - output_dtypes_enc, - inplace_pattern_enc, - True, # allow_core_scalar - (), # constant_inputs - outer_inputs, - core_output_shapes, - NO_SIZE, - indexed_inputs_enc, - indexed_outputs_enc, - ) + return impl + else: + impl_src = _build_reduce_impl_src(nout, post_specs) + impl_fn = compile_numba_function_src( + impl_src, + "fused_elemwise_impl", + { + **globals(), + "core_op_fn": core_op_fn, + "input_bc_patterns_enc": input_bc_patterns_enc, + "output_bc_patterns_enc": output_bc_patterns_enc, + "output_dtypes_enc": output_dtypes_enc, + "inplace_pattern_enc": inplace_pattern_enc, + "core_output_shapes": core_output_shapes, + "indexed_inputs_enc": indexed_inputs_enc, + "indexed_outputs_enc": indexed_outputs_enc, + "reduce_outputs_enc": reduce_outputs_enc, + }, + ) - return impl + @overload(indexed_elemwise_fn, jit_options=_jit_options) + def ov_indexed_elemwise_fn(*outer_inputs): + return impl_fn - cache_version = 1 + cache_version = 2 if scalar_cache_key is None: key = None else: + reduced_key = tuple( + (type(r[0]).__name__, r[1], str(np.dtype(r[3]))) if r is not None else None + for r in reduced_outputs + ) key = str( ( type(op), @@ -808,6 +1004,7 @@ def impl(*outer_inputs): indexed_inputs, idx_broadcastable, indexed_outputs, + reduced_key, scalar_cache_key, ) ) diff --git a/pytensor/link/numba/dispatch/random.py b/pytensor/link/numba/dispatch/random.py index 28ade0bf3a..2f7f8b3700 100644 --- a/pytensor/link/numba/dispatch/random.py +++ b/pytensor/link/numba/dispatch/random.py @@ -23,6 +23,7 @@ from pytensor.link.numba.dispatch.vectorize_codegen import ( NO_INDEXED_INPUTS, NO_INDEXED_OUTPUTS, + NO_REDUCE_OUTPUTS, NO_SIZE, _jit_options, _vectorized, @@ -490,6 +491,7 @@ def impl(core_shape, rng, size, *dist_params): else numba_ndarray.to_fixed_tuple(size, size_len), NO_INDEXED_INPUTS, NO_INDEXED_OUTPUTS, + NO_REDUCE_OUTPUTS, ) return rng, draws diff --git a/pytensor/link/numba/dispatch/vectorize_codegen.py b/pytensor/link/numba/dispatch/vectorize_codegen.py index c5dc338da8..4a2d151911 100644 --- a/pytensor/link/numba/dispatch/vectorize_codegen.py +++ b/pytensor/link/numba/dispatch/vectorize_codegen.py @@ -3,7 +3,6 @@ import base64 import pickle from collections.abc import Callable, Sequence -from textwrap import indent from typing import Any import numba @@ -17,6 +16,7 @@ from pytensor.link.numba.cache import compile_numba_function_src from pytensor.link.numba.dispatch import basic as numba_basic +from pytensor.link.numba.dispatch.string_codegen import CODE_TOKEN, build_source_code def encode_literals(literals: Sequence) -> str: @@ -24,7 +24,10 @@ def encode_literals(literals: Sequence) -> str: def store_core_outputs( - core_op_fn: Callable, nin: int, nout: int, inc_outputs: frozenset = frozenset() + core_op_fn: Callable, + nin: int, + nout: int, + accum_fns: dict[int, Callable[[str, str], Sequence[str | CODE_TOKEN]]] | None = None, ) -> Callable: """Create a Numba function that wraps a core function and stores its vectorized outputs. @@ -35,14 +38,20 @@ def store_core_outputs( def store_core_outputs(i0, i1, ..., in, o0, o1, ..., on): to0, to1, ..., ton = core_op_fn(i0, i1, ..., in) o0[...] = to0 # direct outputs - o1 += to1 # inc outputs (in-place add works for 0d and Nd) + o1[...] += to1 # accumulating outputs (reduce / indexed-inc) ... - ``inc_outputs`` lists output indices that use ``+=`` instead of ``=``. + ``accum_fns`` maps an output index to a callable ``(out_sym, inner_sym) -> + lines`` producing the in-place accumulation code for that output (e.g. + ``["o1[...] += t1"]`` for a sum reduction, or the multi-line conditional for + a max reduction). Outputs absent from ``accum_fns`` are stored with ``=``. + Both reductions and indexed ``inc`` writes go through this mechanism. """ if getattr(core_op_fn, "handles_out", False): return core_op_fn + accum_fns = accum_fns or {} + inputs = [f"i{i}" for i in range(nin)] outputs = [f"o{i}" for i in range(nout)] inner_outputs = [f"t{output}" for output in outputs] @@ -50,19 +59,22 @@ def store_core_outputs(i0, i1, ..., in, o0, o1, ..., on): inp_signature = ", ".join(inputs) out_signature = ", ".join(outputs) inner_out_signature = ", ".join(inner_outputs) - store_outputs = "\n".join( - f"{output} += {inner_output}" - if i in inc_outputs - else f"{output}[...] = {inner_output}" - for i, (output, inner_output) in enumerate( - zip(outputs, inner_outputs, strict=True) - ) - ) - func_src = f""" -def store_core_outputs({inp_signature}, {out_signature}): - {inner_out_signature} = core_op_fn({inp_signature}) -{indent(store_outputs, " " * 4)} -""" + + code: list[str | CODE_TOKEN] = [ + f"def store_core_outputs({inp_signature}, {out_signature}):", + CODE_TOKEN.INDENT, + f"{inner_out_signature} = core_op_fn({inp_signature})", + ] + for i, (output, inner_output) in enumerate( + zip(outputs, inner_outputs, strict=True) + ): + if i in accum_fns: + code.extend(accum_fns[i](output, inner_output)) + else: + code.append(f"{output}[...] = {inner_output}") + code.append(CODE_TOKEN.DEDENT) + + func_src = build_source_code(code) global_env = {"core_op_fn": core_op_fn} func = compile_numba_function_src( @@ -249,6 +261,7 @@ def _codegen_return_outputs( NO_INDEXED_INPUTS = encode_literals(((), ())) NO_INDEXED_OUTPUTS = encode_literals(()) +NO_REDUCE_OUTPUTS = encode_literals(()) NO_SIZE = None @@ -355,11 +368,17 @@ def make_outputs( input_types: tuple[Any, ...], output_core_shapes: tuple, update_outputs: dict | None = None, + reduce_identities: dict | None = None, ) -> tuple[list[ir.Value], list[types.Array]]: """Allocate output arrays for vectorized loop. ``update_outputs`` maps ``{output_idx: (array, array_type)}`` for outputs that reuse an indexed-write target buffer instead of being freshly allocated. + + ``reduce_identities`` maps ``{output_idx: identity_value}`` for reduction + outputs. Such outputs are freshly allocated (size 1 on the reduced axes, via + their ``bc=True`` pattern) and pre-filled with the reduction identity so the + accumulating store in the loop reduces into them correctly. """ output_arrays = [] output_arry_types = [] @@ -389,6 +408,22 @@ def make_outputs( ] shape = batch_shape + core_shape array = arrayobj._empty_nd_impl(ctx, builder, arrtype, shape) + if reduce_identities is not None and i in reduce_identities: + # Pre-fill the freshly allocated (C-contiguous) buffer with the + # reduction identity. A flat scan over every element is valid + # regardless of which axes are reduced, and seeds each kept-axis + # accumulator cell (size-1 reduced axes included). + nitems = ir.IntType(64)(1) + for dim_len in shape: + nitems = builder.mul(nitems, dim_len) + ident = ctx.get_constant(dtype, reduce_identities[i]) + # bool is an i1 value but stored as i8 in arrays; widen to the + # buffer's element type so the store types match. + elem_ty = array.data.type.pointee + if ident.type != elem_ty: + ident = builder.zext(ident, elem_ty) + with cgutils.for_range(builder, nitems) as loop: + builder.store(ident, builder.gep(array.data, [loop.index])) output_arrays.append(array) # If there is no inplace operation, we know that all output arrays @@ -449,21 +484,13 @@ def _wrap_negative_index(idx_val, dim_size, signed): wrapped = builder.add(idx_val, dim_size) return builder.select(is_neg, wrapped, idx_val) - # Setup loops and initialize accumulators for outputs - # This part corresponds to opening the loops + # Open one loop per iteration dimension. Reduction outputs need no special + # setup here: they carry ``bc=True`` on the reduced axes, so the write_idx + # logic below points every iteration over a reduced axis at memory index 0 + # (the same cell), and the accumulating store reduces into it. loop_stack = [] loops = [] - output_accumulator: list[tuple[Any | None, int | None]] = [(None, None)] * n_outputs - for dim, length in enumerate(iter_shape): - # Find outputs that only have accumulations left - for out in range(n_outputs): - if output_accumulator[out][0] is not None: - continue - if all(output_bc[out][dim:]): - value = outputs[out][0].type.pointee(0) - accu = cgutils.alloca_once_value(builder, value) - output_accumulator[out] = (accu, dim) - + for length in iter_shape: loop = cgutils.for_range(builder, length) loop_stack.append(loop) loops.append(loop.__enter__()) @@ -736,6 +763,7 @@ def _vectorized( size_type, indexed_inputs, indexed_outputs, + reduce_outputs, ): """Vectorized intrinsic with optional indirect indexing for reads and writes. @@ -751,7 +779,12 @@ def _vectorized( ``((out_0, out_1), mode)`` means that index updates outputs out_0 and out_1 with *mode* ``"set"`` or ``"inc"``. - For non-indexed calls, both are ``()``. + ``reduce_outputs`` lists ``(output_idx, identity)`` pairs for reduction + outputs. Such an output carries ``bc=True`` on its reduced axes; the buffer + is allocated size 1 there, pre-filled with ``identity``, and the per-iteration + store (baked into ``core_func`` via ``store_core_outputs``) accumulates into it. + + For non-indexed/non-reducing calls, these are ``()``. """ arg_types = [ core_func, @@ -766,12 +799,14 @@ def _vectorized( size_type, indexed_inputs, indexed_outputs, + reduce_outputs, ] input_bc_patterns = _decode_literal(input_bc_patterns, "input_bc_patterns") output_bc_patterns = _decode_literal(output_bc_patterns, "output_bc_patterns") output_dtypes = _decode_literal(output_dtypes, "output_dtypes") inplace_pattern = _decode_literal(inplace_pattern, "inplace_pattern") + reduce_identities = dict(_decode_literal(reduce_outputs, "reduce_outputs")) indexed_inputs, idx_broadcastable = _decode_literal( indexed_inputs, "indexed_inputs" ) @@ -919,6 +954,7 @@ def codegen(ctx, builder, sig, args): size, _, _, + _, ] = args constant_inputs = cgutils.unpack_tuple(builder, constant_inputs) @@ -1055,6 +1091,7 @@ def codegen(ctx, builder, sig, args): source_input_types, output_core_shapes, update_outputs=update_outputs_dict, + reduce_identities=reduce_identities, ) core_signature = typingctx.resolve_function_type( diff --git a/pytensor/tensor/rewriting/indexed_elemwise.py b/pytensor/tensor/rewriting/indexed_elemwise.py index b3ff3828b3..657dfcccaf 100644 --- a/pytensor/tensor/rewriting/indexed_elemwise.py +++ b/pytensor/tensor/rewriting/indexed_elemwise.py @@ -14,8 +14,8 @@ from pytensor.graph.rewriting.unify import OpPattern from pytensor.graph.utils import InconsistencyError from pytensor.printing import op_debug_information -from pytensor.scalar.basic import Composite -from pytensor.tensor.elemwise import DimShuffle, Elemwise +from pytensor.scalar.basic import AND, OR, XOR, Add, Composite, Maximum, Minimum, Mul +from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.rewriting.elemwise import InplaceElemwiseOptimizer from pytensor.tensor.shape import Reshape, shape_padright from pytensor.tensor.subtensor import ( @@ -27,6 +27,12 @@ from pytensor.tensor.variable import TensorVariable +# CAReduce scalar ops whose reduction the Numba backend can fuse into the loop +# (those for which the codegen has an in-place accumulation: see +# ``accumulate_into_slice`` in the numba elemwise dispatch). +_REDUCE_SCALAR_OPS = (Add, Mul, Maximum, Minimum, AND, OR, XOR) + + def _view_root(view_i, var): """Follow the destroy-handler view chain to the underlying buffer. @@ -250,11 +256,40 @@ class IndexedElemwise(OpFromGraph): Examples:: tgt[idx] += exp(x) → indexed_outputs=[((0,), 0, "inc")] + + reduced_outputs : tuple of ((scalar_op, axes, identity, acc_dtype) | None) + One entry per (inner Elemwise / outer op) output position. + ``None`` if the output is not a reduction. + Otherwise the output is the result of reducing the inner Elemwise output + with ``CAReduce(scalar_op)`` over ``axes``: + + - ``scalar_op``: the commutative/associative binary scalar op of the + reduction (e.g. ``add`` for sum, ``mul`` for prod, ``maximum`` for max). + - ``axes``: tuple of reduced batch axes (in the inner Elemwise's dim space). + - ``identity``: the reduction identity, used to seed the accumulator buffer. + - ``acc_dtype``: dtype the accumulation is carried out in (the Numba loop + accumulates in this dtype, then casts to the output dtype). + + The inner fgraph still holds a faithful ``CAReduce(Elemwise(...))`` so + non-Numba backends evaluate correctly via ``OpFromGraph.perform``; only + the Numba backend reads this spec to fuse the reduction into the loop. + + Examples:: + + sum(exp(x)) → reduced_outputs=[(add, (0,), 0.0, "float64")] """ - def __init__(self, *args, indexed_inputs=(), indexed_outputs=(), **kwargs): + def __init__( + self, + *args, + indexed_inputs=(), + indexed_outputs=(), + reduced_outputs=(), + **kwargs, + ): self.indexed_inputs = indexed_inputs self.indexed_outputs = indexed_outputs + self.reduced_outputs = reduced_outputs # A read buffer can occupy multiple input slots (e.g. read through # several indices); construct_nominal_fgraph dedupes those to one # nominal, leaving the extra slots as unused NominalVariables, which is @@ -264,10 +299,19 @@ def __init__(self, *args, indexed_inputs=(), indexed_outputs=(), **kwargs): super().__init__(*args, on_unused_input="ignore", accept_inplace=True, **kwargs) def __str__(self): + elemwise_str = "Elemwise" for node in self.fgraph.apply_nodes: if isinstance(node.op, Elemwise): - return f"IndexedElemwise{{{node.op!s}}}" - return "IndexedElemwise" + elemwise_str = str(node.op) + break + reductions = [ + f"{type(spec[0]).__name__.lower()}@{spec[1]}" + for spec in self.reduced_outputs + if spec is not None + ] + if reductions: + return f"IndexedElemwise{{{elemwise_str}, reduce[{', '.join(reductions)}]}}" + return f"IndexedElemwise{{{elemwise_str}}}" @op_debug_information.register(IndexedElemwise) @@ -317,6 +361,16 @@ def _op_debug_information_IndexedElemwise(op, node): f"indexed {mode} ({buf_label}, {idx_label})" ) + # Annotate reduced outputs + for out_idx, spec in enumerate(op.reduced_outputs): + if spec is None or out_idx >= len(node.outputs): + continue + scalar_op, axes, _identity, acc_dtype = spec + info[node.outputs[out_idx]] = ( + f"reduced[{type(scalar_op).__name__.lower()}] " + f"over axes {axes} acc={acc_dtype}" + ) + return {node: info} @@ -574,7 +628,25 @@ def apply(self, fgraph): idx_groups[idx_axis_pair][1].append(out_idx) write_targets[out_idx] = client_node - if not idx_groups: + # Find reductions to fuse: an Elemwise output whose sole client is an + # eligible CAReduce. Outputs that are write targets are excluded (an + # output can't be both an indexed write and a reduction). Outputs + # with extra (non-reduce) clients are handled by duplication below. + reduced_outputs = {} # out_idx -> car_node + for out_idx, out in enumerate(node.outputs): + if out_idx in write_targets: + continue + car_clients = [ + c + for c, _ in fgraph.clients[out] + if isinstance(c.op, CAReduce) + and isinstance(c.op.scalar_op, _REDUCE_SCALAR_OPS) + ] + if len(car_clients) != 1: + continue + reduced_outputs[out_idx] = car_clients[0] + + if not idx_groups and not reduced_outputs: continue if must_transpose_write_axes: @@ -658,6 +730,38 @@ def _has_non_write_clients(out_idx): worklist.append(new_node) continue + # If a reduced output also feeds non-reduce consumers, duplicate it via + # Composite so the reduction consumes the duplicate while the original + # stays materialised for the other consumers. We still fuse the + # reduction loop, even if the full output must also be produced. + def _has_non_reduce_clients(out_idx): + car_node = reduced_outputs[out_idx] + return any( + c is not car_node for c, _ in fgraph.clients[node.outputs[out_idx]] + ) + + if reduce_and_direct_use_outs := { + out_idx + for out_idx in reduced_outputs + if _has_non_reduce_clients(out_idx) + }: + new_node, dup_map = self._duplicate_multi_client_outputs( + node, reduce_and_direct_use_outs + ) + replacements = list( + zip(node.outputs, new_node.outputs[: len(node.outputs)]) + ) + for out_idx, dup_idx in dup_map.items(): + car_node = reduced_outputs[out_idx] + new_reduced = car_node.op(new_node.outputs[dup_idx]) + replacements.append((car_node.outputs[0], new_reduced)) + fgraph.replace_all( + replacements, + reason="fuse_reduce_and_direct_outputs", + ) + worklist.append(new_node) + continue + idx_vars = [idx for idx, _axis in idx_groups] fgraph_destroy_map = { @@ -715,6 +819,41 @@ def _has_non_write_clients(out_idx): fgraph_outputs[out_idx] = write_out fgraph_destroy_map[out_idx] = [target_pos] + # Inner fgraph reduced outputs: wrap the Elemwise output in the real + # CAReduce so non-Numba backends compute it faithfully via perform. + # reduced_spec carries the (scalar_op, axes, identity, acc_dtype) the + # Numba backend reads to fuse the reduction into the loop instead. + reduced_spec_by_idx = {} + for out_idx, car_node in sorted(reduced_outputs.items()): + car_op = car_node.op + ndim = node.outputs[out_idx].type.ndim + axes = ( + tuple(range(ndim)) + if car_op.axis is None + else tuple(sorted(car_op.axis)) + ) + acc_dtype = ( + car_op.acc_dtype + if car_op.acc_dtype is not None + else car_node.outputs[0].type.dtype + ) + reduced_spec_by_idx[out_idx] = ( + car_op.scalar_op, + axes, + car_op.scalar_op.identity, + acc_dtype, + ) + fgraph_outputs[out_idx] = car_op(node.outputs[out_idx]) + + reduced_spec = ( + tuple( + reduced_spec_by_idx.get(out_idx) + for out_idx in range(len(node.outputs)) + ) + if reduced_outputs + else () + ) + # indexed_inputs_spec: ((read_positions, axis) | None, ...) # indexed_outputs_spec: ((write_positions, axis, "inc"|"set") | None, ...) indexed_inputs_spec = tuple( @@ -743,6 +882,7 @@ def _has_non_write_clients(out_idx): destroy_map=fgraph_destroy_map, indexed_inputs=indexed_inputs_spec, indexed_outputs=indexed_outputs_spec, + reduced_outputs=reduced_spec, )(*outer_inputs, return_list=True) replacements = [] @@ -751,6 +891,10 @@ def _has_non_write_clients(out_idx): replacements.append( (write_targets[out_idx].outputs[0], new_outs[out_idx]) ) + elif out_idx in reduced_outputs: + replacements.append( + (reduced_outputs[out_idx].outputs[0], new_outs[out_idx]) + ) else: replacements.append((node.outputs[out_idx], new_outs[out_idx])) diff --git a/tests/link/numba/test_indexed_elemwise.py b/tests/link/numba/test_indexed_elemwise.py index 505961d1c1..73e2e32ab3 100644 --- a/tests/link/numba/test_indexed_elemwise.py +++ b/tests/link/numba/test_indexed_elemwise.py @@ -603,3 +603,182 @@ def test_runtime_broadcast_on_index_dim(self): # idx=1, y=5 — should error (shape mismatch, no static broadcast info) with pytest.raises(Exception): fn(np.zeros(100), np.zeros(1, dtype=np.int64), np.zeros(5)) + + +def assert_reduce_fused(fn): + """Assert the graph contains an IndexedElemwise with a fused reduction.""" + nodes = [n for n in fn.maker.fgraph.toposort() if isinstance(n.op, IndexedElemwise)] + assert nodes, "IndexedElemwise not found in fused graph" + assert any( + any(r is not None for r in n.op.reduced_outputs) for n in nodes + ), "No fused reduction (reduced_outputs) found" + + +class TestReductionFusion: + """Reductions (CAReduce) fused into the Elemwise loop, no indexing.""" + + @pytest.mark.parametrize( + "axis", [None, 0, 1, 2, (0, 2), (0, 1), (1, 2)], ids=str + ) + def test_sum_axes(self, axis): + rng = np.random.default_rng(0) + x = pt.tensor3("x") + y = pt.tensor3("y") + out = pt.sum(pt.exp(x) + y, axis=axis) + fn, fn_u = fused_and_unfused([x, y], out) + assert_reduce_fused(fn) + xv, yv = rng.normal(size=(3, 4, 5)), rng.normal(size=(3, 4, 5)) + np.testing.assert_allclose(fn(xv, yv), fn_u(xv, yv), rtol=1e-10) + + @pytest.mark.parametrize("axis", [None, 0, 1], ids=str) + def test_prod(self, axis): + rng = np.random.default_rng(1) + x = pt.matrix("x") + out = pt.prod(pt.exp(x * 0.1), axis=axis) + fn, fn_u = fused_and_unfused([x], out) + assert_reduce_fused(fn) + xv = rng.normal(size=(4, 5)) + np.testing.assert_allclose(fn(xv), fn_u(xv), rtol=1e-8) + + @pytest.mark.parametrize("reduce_fn", [pt.max, pt.min], ids=["max", "min"]) + @pytest.mark.parametrize("axis", [None, 0, 1], ids=str) + def test_max_min(self, reduce_fn, axis): + rng = np.random.default_rng(2) + x = pt.matrix("x") + y = pt.matrix("y") + out = reduce_fn(x + y, axis=axis) + fn, fn_u = fused_and_unfused([x, y], out) + assert_reduce_fused(fn) + xv, yv = rng.normal(size=(6, 7)), rng.normal(size=(6, 7)) + np.testing.assert_allclose(fn(xv, yv), fn_u(xv, yv), rtol=1e-10) + + @pytest.mark.parametrize("reduce_fn", [pt.all, pt.any], ids=["all", "any"]) + @pytest.mark.parametrize("axis", [None, 0, 1], ids=str) + def test_all_any(self, reduce_fn, axis): + rng = np.random.default_rng(3) + x = pt.matrix("x", dtype="bool") + y = pt.matrix("y", dtype="bool") + out = reduce_fn(x & y, axis=axis) + fn, fn_u = fused_and_unfused([x, y], out) + assert_reduce_fused(fn) + xv = rng.integers(0, 2, size=(4, 5)).astype(bool) + yv = rng.integers(0, 2, size=(4, 5)).astype(bool) + np.testing.assert_array_equal(fn(xv, yv), fn_u(xv, yv)) + + @pytest.mark.parametrize("dtype", ["int8", "int32", "uint8"]) + @pytest.mark.parametrize("axis", [None, 0, 1], ids=str) + def test_sum_acc_dtype_widening(self, dtype, axis): + """Sum of small int dtype accumulates in a wider acc_dtype.""" + rng = np.random.default_rng(4) + x = pt.matrix("x", dtype=dtype) + out = pt.sum(x + x, axis=axis) + fn, fn_u = fused_and_unfused([x], out) + assert_reduce_fused(fn) + info = np.iinfo(dtype) + xv = rng.integers(0, min(info.max // 2, 50), size=(40, 40)).astype(dtype) + np.testing.assert_array_equal(fn(xv), fn_u(xv)) + # Result must be the wide acc dtype, not overflow the input dtype + assert fn(xv).dtype == fn_u(xv).dtype + + def test_scalar_and_1d(self): + rng = np.random.default_rng(5) + x = pt.vector("x") + out = pt.sum(pt.exp(x)) + fn, fn_u = fused_and_unfused([x], out) + assert_reduce_fused(fn) + xv = rng.normal(size=(17,)) + np.testing.assert_allclose(fn(xv), fn_u(xv), rtol=1e-10) + + def test_non_c_contiguous_input(self): + """Reduction over a transposed (non-C-contiguous) intermediate.""" + rng = np.random.default_rng(6) + x = pt.matrix("x") + out = pt.sum((x + 1.0).T, axis=0) + fn, fn_u = fused_and_unfused([x], out) + xv = rng.normal(size=(8, 5)) + np.testing.assert_allclose(fn(xv), fn_u(xv), rtol=1e-10) + + +class TestReductionWithIndexing: + """Reductions composed with indexed reads in a single fused loop.""" + + @pytest.mark.parametrize("axis", [None, 0, 1], ids=str) + def test_sum_gather(self, axis): + rng = np.random.default_rng(10) + x = pt.matrix("x") + y = pt.matrix("y") + idx = pt.lvector("idx") + out = pt.sum(x[idx] + y, axis=axis) + fn, fn_u = fused_and_unfused([x, y, idx], out) + assert_reduce_fused(fn) + xv = rng.normal(size=(8, 5)) + yv = rng.normal(size=(4, 5)) + idxv = rng.integers(0, 8, size=4) + np.testing.assert_allclose(fn(xv, yv, idxv), fn_u(xv, yv, idxv), rtol=1e-10) + + def test_max_gather(self): + rng = np.random.default_rng(11) + x = pt.matrix("x") + idx = pt.lvector("idx") + out = pt.max(pt.exp(x[idx]), axis=0) + fn, fn_u = fused_and_unfused([x, idx], out) + assert_reduce_fused(fn) + xv = rng.normal(size=(8, 5)) + idxv = rng.integers(0, 8, size=4) + np.testing.assert_allclose(fn(xv, idxv), fn_u(xv, idxv), rtol=1e-10) + + +class TestReductionMultiOutput: + """Reduction plus direct use of the same Elemwise output (duplication).""" + + def test_sum_and_direct(self): + rng = np.random.default_rng(20) + x = pt.matrix("x") + y = pt.matrix("y") + f = pt.exp(x) + y + out = [pt.sum(f, axis=0), f] + fn, fn_u = fused_and_unfused([x, y], out) + assert_reduce_fused(fn) + xv, yv = rng.normal(size=(4, 5)), rng.normal(size=(4, 5)) + r, ru = fn(xv, yv), fn_u(xv, yv) + np.testing.assert_allclose(r[0], ru[0], rtol=1e-10) + np.testing.assert_allclose(r[1], ru[1], rtol=1e-10) + + def test_two_reductions_same_source(self): + """sum and max of the same elemwise output (two CAReduce clients).""" + rng = np.random.default_rng(21) + x = pt.matrix("x") + f = pt.exp(x) + out = [pt.sum(f, axis=0), pt.max(f, axis=0)] + fn, fn_u = fused_and_unfused([x], out) + xv = rng.normal(size=(4, 5)) + r, ru = fn(xv), fn_u(xv) + np.testing.assert_allclose(r[0], ru[0], rtol=1e-10) + np.testing.assert_allclose(r[1], ru[1], rtol=1e-10) + + +class TestReductionPythonMode: + """The fused op evaluates correctly outside JIT (OpFromGraph.perform).""" + + def test_perform_matches(self): + rng = np.random.default_rng(30) + x = pt.matrix("x") + y = pt.matrix("y") + idx = pt.lvector("idx") + fn = pytensor.function( + [x, y, idx], pt.sum(x[idx] + y, axis=1), mode=NUMBA_MODE, trust_input=True + ) + node = next( + n for n in fn.maker.fgraph.toposort() if isinstance(n.op, IndexedElemwise) + ) + # Re-apply the exact fused op and evaluate it via OpFromGraph.perform + fresh = [inp.type() for inp in node.inputs] + perform_fn = pytensor.function( + fresh, node.op(*fresh, return_list=True), mode="FAST_COMPILE" + ) + xv = rng.normal(size=(8, 5)) + yv = rng.normal(size=(4, 5)) + idxv = rng.integers(0, 8, size=4) + np.testing.assert_allclose( + perform_fn(xv, yv, idxv)[0], np.sum(xv[idxv] + yv, axis=1), rtol=1e-10 + ) diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index 7ac8f4c65b..dac8690abd 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -249,7 +249,11 @@ def _raise_on_opt_error(self): "add_mul_fusion", "inplace", ], - exclude=["cxx_only", "BlasOpt"], + # Exclude both careduce-fusion paths so reductions stay unfused here: + # cxx_only covers local_careduce_fusion (C backend); the indexed/reduce + # fusion is the Numba-specific equivalent. This class tests the generic + # Composite fusion, independent of the active backend. + exclude=["cxx_only", "BlasOpt", "fuse_indexed_into_elemwise"], ) mode = Mode(get_default_mode().linker, rewrites) _shared = staticmethod(shared) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 83ec5a5e10..f33a20013b 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -2826,7 +2826,13 @@ class TestLocalSumProd: """Test sum/prod rewrites.""" def setup_method(self): - self.mode = get_default_mode().including("canonicalize", "specialize") + # Exclude the Numba reduction fusion so CAReduce nodes stay visible in + # the toposort for the structural assertions below. + self.mode = ( + get_default_mode() + .including("canonicalize", "specialize") + .excluding("fuse_indexed_into_elemwise") + ) def test_local_sum_prod_of_scalar_mul(self): # Test the rewrite `local_sum_prod_mul_by_scalar` for both Sum and @@ -3335,7 +3341,9 @@ def test_local_prod_of_div(self): c_val = rng.standard_normal((2, 2, 2)).astype(config.floatX) d_val = np.asarray(rng.standard_normal(), config.floatX) - default_mode = get_default_mode() + # Exclude the Numba reduction fusion so the outer reduction op stays + # visible in the toposort for the structural assertions below. + default_mode = get_default_mode().excluding("fuse_indexed_into_elemwise") # `FusionOptimizer` is included to make sure that `expected_outer_operator` # remains the same for all rewrite modes. mode_with_rewrite = default_mode.including( @@ -3392,8 +3400,12 @@ def test_local_prod_of_div(self): class TestLocalReduce: def setup_method(self): - self.mode = get_default_mode().including( - "canonicalize", "specialize", "uncanonicalize" + # Exclude the Numba reduction fusion so CAReduce nodes stay visible in + # the toposort for the structural assertions below. + self.mode = ( + get_default_mode() + .including("canonicalize", "specialize", "uncanonicalize") + .excluding("fuse_indexed_into_elemwise") ) def test_local_reduce_broadcast_all_0(self): diff --git a/tests/tensor/rewriting/test_uncanonicalize.py b/tests/tensor/rewriting/test_uncanonicalize.py index 9d5011b6db..e842703d45 100644 --- a/tests/tensor/rewriting/test_uncanonicalize.py +++ b/tests/tensor/rewriting/test_uncanonicalize.py @@ -23,8 +23,12 @@ class TestMinMax: def setup_method(self): - self.mode = pytensor.compile.mode.get_default_mode().including( - "canonicalize", "fast_run" + # Exclude the Numba reduction fusion so the Max/Min CAReduce nodes stay + # visible in the toposort for the structural assertions below. + self.mode = ( + pytensor.compile.mode.get_default_mode() + .including("canonicalize", "fast_run") + .excluding("fuse_indexed_into_elemwise") ) def test_optimization_min(self):