Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pytensor/link/numba/dispatch/blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pytensor.link.numba.dispatch.vectorize_codegen import (
NO_INDEXED_INPUTS,
NO_INDEXED_OUTPUTS,
NO_REDUCE_OUTPUTS,
NO_SIZE,
_jit_options,
_vectorized,
Expand Down Expand Up @@ -96,6 +97,7 @@ def impl(*inputs_and_core_shapes):
NO_SIZE,
NO_INDEXED_INPUTS,
NO_INDEXED_OUTPUTS,
NO_REDUCE_OUTPUTS,
)

return impl
Expand Down
247 changes: 222 additions & 25 deletions pytensor/link/numba/dispatch/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from pytensor.link.numba.dispatch.vectorize_codegen import (
NO_INDEXED_INPUTS,
NO_INDEXED_OUTPUTS,
NO_REDUCE_OUTPUTS,
NO_SIZE,
_jit_options,
_vectorized,
Expand Down Expand Up @@ -140,6 +141,38 @@ def scalar_in_place_fn_Minimum(op, idx, res, arr):
]


# Scalar reduce ops that ``accumulate_into_slice`` (and thus fused reductions /
# indexed inc) supports, keyed by scalar-op type and paired with a source-code
# template for the in-place form used on the output slice. ``{out}`` and
# ``{inner}`` are placeholders that ``accumulate_into_slice`` fills in with
# ``str.format``: the expression for the output slice and the expression for
# the value accumulated into it.
# Augmented assignment is used where Numba supports it directly on 0-d arrays
# (``out`` is the core output slice, which is 0-d for scalar cores); Maximum and
# Minimum lack an augmented operator so they use a ufunc + ellipsis-set instead.
_SLICE_ACCUMULATE = {
Add: "{out} += {inner}",
Mul: "{out} *= {inner}",
AND: "{out} &= {inner}",
OR: "{out} |= {inner}",
XOR: "{out} ^= {inner}",
Maximum: "{out}[...] = np.maximum({out}, {inner})",
Minimum: "{out}[...] = np.minimum({out}, {inner})",
}


def accumulate_into_slice(scalar_op, out: str, inner: str) -> list[str]:
    """Return the source lines that accumulate ``inner`` into ``out`` in place.

    Unlike ``scalar_in_place_fn`` (which indexes ``res[idx]`` and breaks on 0-d
    cores), this operates on the whole slice so it is valid for both scalar
    (0-d) and array cores. Used by both fused reductions and indexed ``inc``.

    Parameters
    ----------
    scalar_op
        Scalar op instance. Its exact type (``type(scalar_op)``, subclasses do
        not match) selects the template in ``_SLICE_ACCUMULATE``.
    out
        Source expression for the output slice being updated.
    inner
        Source expression for the value accumulated into ``out``.

    Returns
    -------
    list[str]
        A one-element list holding the formatted accumulation statement.

    Raises
    ------
    NotImplementedError
        If ``type(scalar_op)`` has no entry in ``_SLICE_ACCUMULATE``.
    """
    # A plain lookup (rather than ``try``/``except KeyError``) keeps the
    # internal ``KeyError`` out of the traceback of the error raised below.
    template = _SLICE_ACCUMULATE.get(type(scalar_op))
    if template is None:
        raise NotImplementedError(
            f"No fused-reduction accumulation for scalar op {scalar_op}"
        )
    return [template.format(out=out, inner=inner)]


@intrinsic
def _address_as_void_pointer(typingctx, src):
"""Returns a void pointer from a given memory address."""
Expand Down Expand Up @@ -695,6 +728,7 @@ def impl(*inputs):
NO_SIZE,
NO_INDEXED_INPUTS,
NO_INDEXED_OUTPUTS,
NO_REDUCE_OUTPUTS,
)

return impl
Expand All @@ -715,6 +749,86 @@ def impl(*inputs):
return elemwise, elemwise_key


def _reduce_identity(identity, acc_dtype):
"""Coerce a reduction identity to a concrete numpy scalar in ``acc_dtype``.

Non-finite identities (``±inf`` for max/min) become the dtype's bounds for
integer accumulators, mirroring ``create_multiaxis_reducer``.
"""
acc_dtype = np.dtype(acc_dtype)
if acc_dtype.kind in "ui" and not np.isfinite(identity):
identity = (
np.iinfo(acc_dtype).max
if np.isposinf(identity)
else np.iinfo(acc_dtype).min
)
return acc_dtype.type(identity)


def _build_reduce_impl_src(nout, post_specs):
    """Generate the source of the ``@overload`` impl for a fused op that reduces.

    The generated ``fused_elemwise_impl`` first calls ``_vectorized`` (which
    produces a keepdims, size-1-on-reduced-axes buffer in ``acc_dtype``) and
    binds the result to ``raw``. Each output ``i`` is then post-processed
    according to ``post_specs[i]``:

    * ``None``: passthrough, the raw output is returned untouched;
    * ``(kept_axes, out_dtype, cast_needed)`` with empty ``kept_axes``: full
      reduction, emitted as a 0-d ``np.array`` of the single accumulated value
      in ``out_dtype``;
    * ``(kept_axes, out_dtype, cast_needed)`` otherwise: the buffer is reshaped
      to the sizes of the kept axes and, if ``cast_needed``, cast with
      ``astype`` to ``out_dtype``.

    With ``nout == 1`` the generated function reads ``raw`` directly and
    returns the single output; otherwise it indexes ``raw[i]`` and returns a
    tuple.
    """
    single_output = nout == 1

    # Positional arguments of the generated ``_vectorized(...)`` call, in order.
    vectorized_args = (
        "core_op_fn",
        "input_bc_patterns_enc",
        "output_bc_patterns_enc",
        "output_dtypes_enc",
        "inplace_pattern_enc",
        "True",
        "()",
        "outer_inputs",
        "core_output_shapes",
        "NO_SIZE",
        "indexed_inputs_enc",
        "indexed_outputs_enc",
        "reduce_outputs_enc",
    )

    lines: list[str | CODE_TOKEN] = [
        "def fused_elemwise_impl(*outer_inputs):",
        CODE_TOKEN.INDENT,
        "raw = _vectorized(",
        CODE_TOKEN.INDENT,
    ]
    lines.extend(f"{arg}," for arg in vectorized_args)
    lines.append(CODE_TOKEN.DEDENT)
    lines.append(")")

    results = []
    for idx in range(nout):
        target = f"o{idx}"
        results.append(target)
        source = "raw" if single_output else f"raw[{idx}]"

        spec = post_specs[idx]
        if spec is None:
            # Passthrough output: hand back the raw buffer as is.
            lines.append(f"{target} = {source}")
            continue

        kept_axes, out_dtype, cast_needed = spec
        if out_dtype == "bool":
            np_name = "bool_"
        else:
            np_name = out_dtype

        if kept_axes:
            # Partial reduction: drop the size-1 reduced axes via reshape.
            dims = tuple(f"{source}.shape[{axis}]" for axis in kept_axes)
            value = f"{source}.reshape({create_tuple_string(dims)})"
            if cast_needed:
                value = f"{value}.astype(np.{np_name})"
        else:
            # Full reduction → 0-d array of the single accumulated value.
            value = f"np.array({source}.ravel()[0], dtype=np.{np_name})"
        lines.append(f"{target} = {value}")

    returned = results[0] if single_output else create_tuple_string(tuple(results))
    lines.append(f"return {returned}")
    lines.append(CODE_TOKEN.DEDENT)
    return build_source_code(lines)


@register_funcify_and_cache_key(IndexedElemwise)
def numba_funcify_IndexedElemwise(op, node, **kwargs):
"""Generate fused Elemwise Numba code with indexed reads and updates.
Expand All @@ -739,6 +853,7 @@ def numba_funcify_IndexedElemwise(op, node, **kwargs):
n_indices = len(indexed_inputs)
nin_elemwise = len(elemwise_node.inputs)
nout = len(elemwise_node.outputs)
reduced_outputs = op.reduced_outputs or ((None,) * nout)

inc_outputs = frozenset(
out_idx
Expand All @@ -747,14 +862,60 @@ def numba_funcify_IndexedElemwise(op, node, **kwargs):
for out_idx in entry[0]
if entry[2] == "inc"
)
reduced_idxs = frozenset(i for i, r in enumerate(reduced_outputs) if r is not None)
assert not (inc_outputs & reduced_idxs), (
"An output cannot be both an indexed write and a reduction"
)

# accum_fns bake per-output in-place accumulation into store_core_outputs:
# indexed `inc` writes accumulate with +=; reductions use their scalar op
# (Add/Mul/Maximum/...). Both go through accumulate_into_slice so they are
# valid on 0-d (scalar core) slices. Disjoint output indices.
accum_fns: dict = {}
for out_idx in inc_outputs:
accum_fns[out_idx] = lambda out, inner: [f"{out} += {inner}"]
for i in reduced_idxs:
accum_fns[i] = (
lambda out, inner, _op=reduced_outputs[i][0]: accumulate_into_slice(
_op, out, inner
)
)

core_op_fn = store_core_outputs(
scalar_op_fn, nin=nin_elemwise, nout=nout, inc_outputs=inc_outputs
scalar_op_fn, nin=nin_elemwise, nout=nout, accum_fns=accum_fns
)

input_bc_patterns = tuple(inp.type.broadcastable for inp in elemwise_node.inputs)
output_bc_patterns = tuple(out.type.broadcastable for out in elemwise_node.outputs)
output_dtypes = tuple(out.type.dtype for out in node.outputs)

# Reduction outputs keep their reduced axes as bc=True (size-1, keepdims) and
# are accumulated in acc_dtype; post_specs squeezes + casts them back to the
# true CAReduce output below.
output_bc_patterns_list = []
output_dtypes_list = []
reduce_identities = []
post_specs: list = []
has_reductions = bool(reduced_idxs)
for i, inner_out in enumerate(elemwise_node.outputs):
spec = reduced_outputs[i]
if spec is None:
output_bc_patterns_list.append(inner_out.type.broadcastable)
output_dtypes_list.append(node.outputs[i].type.dtype)
post_specs.append(None)
continue
_reduce_op, axes, identity, acc_dtype = spec
bc = list(inner_out.type.broadcastable)
for ax in axes:
bc[ax] = True
output_bc_patterns_list.append(tuple(bc))
output_dtypes_list.append(str(np.dtype(acc_dtype)))
reduce_identities.append((i, _reduce_identity(identity, acc_dtype)))
kept_axes = tuple(d for d in range(len(bc)) if d not in axes)
out_dtype = node.outputs[i].type.dtype
cast_needed = np.dtype(acc_dtype) != np.dtype(out_dtype)
post_specs.append((kept_axes, out_dtype, cast_needed))

output_bc_patterns = tuple(output_bc_patterns_list)
output_dtypes = tuple(output_dtypes_list)
inplace_pattern = tuple(elemwise_node.op.inplace_pattern.items())
core_output_shapes = tuple(() for _ in range(nout))

Expand All @@ -768,36 +929,71 @@ def numba_funcify_IndexedElemwise(op, node, **kwargs):
inplace_pattern_enc = encode_literals(inplace_pattern)
indexed_inputs_enc = encode_literals((indexed_inputs, idx_broadcastable))
indexed_outputs_enc = encode_literals(indexed_outputs)
reduce_outputs_enc = encode_literals(tuple(reduce_identities))

def indexed_elemwise_fn(*outer_inputs):
raise NotImplementedError(
"IndexedElemwise cannot be evaluated in Python (non-JIT) mode."
)
# Python-mode fallback (e.g. Numba's ``eval_python_only`` path, which
# runs the funcified function without JIT). Evaluate the inner fgraph
# faithfully, exactly like ``OpFromGraph.perform`` — so the fused op
# still produces valid results when executed directly in Python. The
# JIT path never reaches this body: the ``@overload`` below supplies the
# vectorized loop implementation.
res = op.fn(*outer_inputs)
return res[0] if len(res) == 1 else tuple(res)

if not has_reductions:

@overload(indexed_elemwise_fn, jit_options=_jit_options)
def ov_indexed_elemwise_fn(*outer_inputs):
def impl(*outer_inputs):
return _vectorized(
core_op_fn,
input_bc_patterns_enc,
output_bc_patterns_enc,
output_dtypes_enc,
inplace_pattern_enc,
True, # allow_core_scalar
(), # constant_inputs
outer_inputs,
core_output_shapes,
NO_SIZE,
indexed_inputs_enc,
indexed_outputs_enc,
reduce_outputs_enc,
)

@overload(indexed_elemwise_fn, jit_options=_jit_options)
def ov_indexed_elemwise_fn(*outer_inputs):
def impl(*outer_inputs):
return _vectorized(
core_op_fn,
input_bc_patterns_enc,
output_bc_patterns_enc,
output_dtypes_enc,
inplace_pattern_enc,
True, # allow_core_scalar
(), # constant_inputs
outer_inputs,
core_output_shapes,
NO_SIZE,
indexed_inputs_enc,
indexed_outputs_enc,
)
return impl
else:
impl_src = _build_reduce_impl_src(nout, post_specs)
impl_fn = compile_numba_function_src(
impl_src,
"fused_elemwise_impl",
{
**globals(),
"core_op_fn": core_op_fn,
"input_bc_patterns_enc": input_bc_patterns_enc,
"output_bc_patterns_enc": output_bc_patterns_enc,
"output_dtypes_enc": output_dtypes_enc,
"inplace_pattern_enc": inplace_pattern_enc,
"core_output_shapes": core_output_shapes,
"indexed_inputs_enc": indexed_inputs_enc,
"indexed_outputs_enc": indexed_outputs_enc,
"reduce_outputs_enc": reduce_outputs_enc,
},
)

return impl
@overload(indexed_elemwise_fn, jit_options=_jit_options)
def ov_indexed_elemwise_fn(*outer_inputs):
return impl_fn

cache_version = 1
cache_version = 2
if scalar_cache_key is None:
key = None
else:
reduced_key = tuple(
(type(r[0]).__name__, r[1], str(np.dtype(r[3]))) if r is not None else None
for r in reduced_outputs
)
key = str(
(
type(op),
Expand All @@ -808,6 +1004,7 @@ def impl(*outer_inputs):
indexed_inputs,
idx_broadcastable,
indexed_outputs,
reduced_key,
scalar_cache_key,
)
)
Expand Down
2 changes: 2 additions & 0 deletions pytensor/link/numba/dispatch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from pytensor.link.numba.dispatch.vectorize_codegen import (
NO_INDEXED_INPUTS,
NO_INDEXED_OUTPUTS,
NO_REDUCE_OUTPUTS,
NO_SIZE,
_jit_options,
_vectorized,
Expand Down Expand Up @@ -490,6 +491,7 @@ def impl(core_shape, rng, size, *dist_params):
else numba_ndarray.to_fixed_tuple(size, size_len),
NO_INDEXED_INPUTS,
NO_INDEXED_OUTPUTS,
NO_REDUCE_OUTPUTS,
)
return rng, draws

Expand Down
Loading
Loading