From 60c0ae3fdcf65254b2d05a9c35365c1964f90bf5 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Sun, 3 May 2026 13:35:39 +0200 Subject: [PATCH] numba: inline="always" to speedup trivial op compilation --- pytensor/link/numba/dispatch/basic.py | 27 +++++---- pytensor/link/numba/dispatch/elemwise.py | 6 +- pytensor/link/numba/dispatch/scalar.py | 46 +++++++-------- pytensor/link/numba/dispatch/shape.py | 12 ++-- pytensor/link/numba/dispatch/tensor_basic.py | 10 ++-- .../link/numba/dispatch/vectorize_codegen.py | 4 +- tests/benchmarks/test_compilation.py | 56 +++++++++++++------ tests/tensor/rewriting/test_basic.py | 5 +- tests/tensor/test_shape.py | 10 ++-- 9 files changed, 103 insertions(+), 73 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 6a01532380..68fc2d6524 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -58,7 +58,7 @@ def _filter_numba_warnings(): def numba_njit( - *args, fastmath=None, final_function: bool = False, **kwargs + *args, fastmath=None, inline=None, final_function: bool = False, **kwargs ) -> Callable: """A thin wrapper around `numba.njit`. @@ -88,6 +88,9 @@ def numba_njit( kwargs.setdefault("no_cpython_wrapper", True) kwargs.setdefault("no_cfunc_wrapper", True) + if inline is not None: + kwargs["inline"] = inline + if len(args) > 0 and callable(args[0]): return _njit(*args[1:], fastmath=fastmath, **kwargs)(args[0]) # type: ignore else: @@ -448,15 +451,19 @@ def numba_funcify_ensure_cache(op, *args, **kwargs) -> tuple[Callable, str | Non if config.numba__cache and config.compiler_verbose: print(f"{op} of type {type(op)} will not be cached by PyTensor.\n") # noqa: T201 return jitable_func, None - else: - op_name = jitable_func.__name__ - cached_func = compile_numba_function_src( - src=f"def {op_name}(*args): return jitable_func(*args)", - function_name=op_name, - global_env=globals() | {"jitable_func": jitable_func}, - cache_key=f"{cache_key}_fastmath{int(config.numba__fastmath)}", - ) - return numba_njit(cached_func, cache=True), cache_key + + # Inline functions get baked into the caller's cache entry and can't be independently cached + if getattr(jitable_func, "targetoptions", {}).get("inline") == "always": + return jitable_func, cache_key + + op_name = jitable_func.__name__ + cached_func = compile_numba_function_src( + src=f"def {op_name}(*args): return jitable_func(*args)", + function_name=op_name, + global_env=globals() | {"jitable_func": jitable_func}, + cache_key=f"{cache_key}_fastmath{int(config.numba__fastmath)}", + ) + return numba_njit(cached_func, cache=True), cache_key def cache_key_for_constant(data): diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 2308ddffdc..03e7ddd112 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -456,7 +456,7 @@ def numba_funcify_DimShuffle(op: DimShuffle, node, **kwargs): if new_order == (): # Special case needed because of https://github.com/numba/numba/issues/9933 - @numba_basic.numba_njit + @numba_basic.numba_njit(inline="always") def squeeze_to_0d(x): if not x.size == 1: raise ValueError( @@ -473,13 +473,13 @@ def squeeze_to_0d(x): new_shape = shape_template new_strides = strides_template - @numba_basic.numba_njit + @numba_basic.numba_njit(inline="always") def dimshuffle(x): return as_strided(np.asarray(x), shape=new_shape, strides=new_strides) else: - @numba_basic.numba_njit + @numba_basic.numba_njit(inline="always") def dimshuffle(x): old_shape = x.shape old_strides = x.strides diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py index 777b4d5a6c..39953e85f2 100644 --- a/pytensor/link/numba/dispatch/scalar.py +++ b/pytensor/link/numba/dispatch/scalar.py @@ -136,12 +136,12 @@ def {scalar_op_fn_name}({input_signature}): # Functions that call a function pointer can't be cached cache_key = None if cython_func else scalar_op_cache_key(op) - return numba_basic.numba_njit(scalar_op_fn), cache_key + return numba_basic.numba_njit(scalar_op_fn, inline="always"), cache_key @register_funcify_and_cache_key(Switch) def numba_funcify_Switch(op, node, **kwargs): - @numba_basic.numba_njit + @numba_basic.numba_njit(inline="always") def switch(condition, x, y): if condition: return x @@ -174,34 +174,34 @@ def numba_funcify_Pow(op, node, **kwargs): def pow(x, y): return x**y - # Numba power fails when exponents are discrete integers and fasthmath=True - # https://github.com/numba/numba/issues/9554 - fastmath = False if np.dtype(pow_dtype).kind in "ibu" else None - - return numba_basic.numba_njit(pow, fastmath=fastmath), scalar_op_cache_key( - op, cache_version=1 - ) + # Integer exponents break fastmath and inline (numba#9554) + integer_exp = np.dtype(pow_dtype).kind in "ibu" + return numba_basic.numba_njit( + pow, + fastmath=False if integer_exp else None, + inline=None if integer_exp else "always", + ), scalar_op_cache_key(op, cache_version=1) @register_funcify_and_cache_key(Add) def numba_funcify_Add(op, node, **kwargs): nary_add_fn = binary_to_nary_func(node.inputs, "add", "+") - return numba_basic.numba_njit(nary_add_fn), scalar_op_cache_key(op) + return numba_basic.numba_njit(nary_add_fn, inline="always"), scalar_op_cache_key(op) @register_funcify_and_cache_key(Mul) def numba_funcify_Mul(op, node, **kwargs): nary_mul_fn = binary_to_nary_func(node.inputs, "mul", "*") - return numba_basic.numba_njit(nary_mul_fn), scalar_op_cache_key(op) + return numba_basic.numba_njit(nary_mul_fn, inline="always"), scalar_op_cache_key(op) @register_funcify_and_cache_key(Cast) def numba_funcify_Cast(op, node, **kwargs): dtype = np.dtype(op.o_type.dtype) - @numba_basic.numba_njit + @numba_basic.numba_njit(inline="always") def cast(x): return numba_basic.direct_cast(x, dtype) @@ -210,7 +210,7 @@ def cast(x): @register_funcify_and_cache_key(Identity) def numba_funcify_type_casting(op, **kwargs): - @numba_basic.numba_njit + @numba_basic.numba_njit(inline="always") def identity(x): return x @@ -219,7 +219,7 @@ def identity(x): @register_funcify_and_cache_key(Clip) def numba_funcify_Clip(op, **kwargs): - @numba_basic.numba_njit + @numba_basic.numba_njit(inline="always") def clip(x, min_val, max_val): if x < min_val: return min_val @@ -247,7 +247,7 @@ def numba_funcify_Composite(op, node, **kwargs): @register_funcify_and_cache_key(Second) def numba_funcify_Second(op, node, **kwargs): - @numba_basic.numba_njit + @numba_basic.numba_njit(inline="always") def second(x, y): return y @@ -256,7 +256,7 @@ def second(x, y): @register_funcify_and_cache_key(Reciprocal) def numba_funcify_Reciprocal(op, node, **kwargs): - @numba_basic.numba_njit + @numba_basic.numba_njit(inline="always") def reciprocal(x): # This is how the C-backend implementation works return np.divide(np.float32(1.0), x) @@ -275,7 +275,7 @@ def numba_funcify_Sigmoid(op, node, **kwargs): "uint64": np.float64, }[inp_dtype] - @numba_basic.numba_njit + @numba_basic.numba_njit(inline="always") def sigmoid(x): # Can't negate uint float_x = numba_basic.direct_cast(x, upcast_uint_dtype) @@ -283,7 +283,7 @@ def sigmoid(x): else: - @numba_basic.numba_njit + @numba_basic.numba_njit(inline="always") def sigmoid(x): return 1 / (1 + np.exp(-x)) @@ -292,7 +292,7 @@ def sigmoid(x): @register_funcify_and_cache_key(GammaLn) def numba_funcify_GammaLn(op, node, **kwargs): - @numba_basic.numba_njit + @numba_basic.numba_njit(inline="always") def gammaln(x): return math.lgamma(x) @@ -301,7 +301,7 @@ def gammaln(x): @register_funcify_and_cache_key(Log1mexp) def numba_funcify_Log1mexp(op, node, **kwargs): - @numba_basic.numba_njit + @numba_basic.numba_njit(inline="always") def logp1mexp(x): if x < np.log(0.5): return np.log1p(-np.exp(x)) @@ -317,7 +317,7 @@ def numba_funcify_Erf(op, node, **kwargs): # Complex not supported by numba return numba_funcify_ScalarOp(op, node=node, **kwargs) - @numba_basic.numba_njit + @numba_basic.numba_njit(inline="always") def erf(x): return math.erf(x) @@ -326,7 +326,7 @@ def erf(x): @register_funcify_and_cache_key(Erfc) def numba_funcify_Erfc(op, **kwargs): - @numba_basic.numba_njit + @numba_basic.numba_njit(inline="always") def erfc(x): return math.erfc(x) @@ -347,7 +347,7 @@ def numba_funcify_Softplus(op, node, **kwargs): upcast_uint_dtype = None out_dtype = np.dtype(node.outputs[0].type.dtype) - @numba_basic.numba_njit + @numba_basic.numba_njit(inline="always") def softplus(x): if x < -37.0: value = np.exp(x) diff --git a/pytensor/link/numba/dispatch/shape.py b/pytensor/link/numba/dispatch/shape.py index b6a5533809..0cb14f3e5f 100644 --- a/pytensor/link/numba/dispatch/shape.py +++ b/pytensor/link/numba/dispatch/shape.py @@ -12,7 +12,7 @@ @register_funcify_default_op_cache_key(Shape) def numba_funcify_Shape(op, **kwargs): - @numba_basic.numba_njit + @numba_basic.numba_njit(inline="always") def shape(x): return np.asarray(np.shape(x)) @@ -23,7 +23,7 @@ def shape(x): def numba_funcify_Shape_i(op, **kwargs): i = op.i - @numba_basic.numba_njit + @numba_basic.numba_njit(inline="always") def shape_i(x): return np.asarray(np.shape(x)[i]) @@ -36,7 +36,7 @@ def numba_funcify_SpecifyShape(op, node, **kwargs): shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))] func_conditions = [ - f"assert x.shape[{i}] == {eval_dim_name}, f'SpecifyShape: dim {{{i}}} of input has shape {{x.shape[{i}]}}, expected {{{eval_dim_name}.item()}}.'" + f"assert x.shape[{i}] == {eval_dim_name}, 'SpecifyShape: shape mismatch in dim {i}'" for i, (node_dim_input, eval_dim_name) in enumerate( zip(shape_inputs, shape_input_names, strict=True) ) @@ -52,7 +52,7 @@ def specify_shape(x, {", ".join(shape_input_names)}): ) specify_shape = compile_function_src(func, "specify_shape", globals()) - return numba_basic.numba_njit(specify_shape) + return numba_basic.numba_njit(specify_shape, inline="always") @register_funcify_default_op_cache_key(Reshape) @@ -61,13 +61,13 @@ def numba_funcify_Reshape(op, **kwargs): if ndim == 0: - @numba_basic.numba_njit + @numba_basic.numba_njit(inline="always") def reshape(x, shape): return np.asarray(x.item()) else: - @numba_basic.numba_njit + @numba_basic.numba_njit(inline="always") def reshape(x, shape): # TODO: Use this until https://github.com/numba/numba/issues/7353 is closed. return np.reshape( diff --git a/pytensor/link/numba/dispatch/tensor_basic.py b/pytensor/link/numba/dispatch/tensor_basic.py index a5babae7de..cf8826397b 100644 --- a/pytensor/link/numba/dispatch/tensor_basic.py +++ b/pytensor/link/numba/dispatch/tensor_basic.py @@ -50,7 +50,7 @@ def allocempty({", ".join(shape_var_names)}): alloc_def_src, "allocempty", globals() | {"np": np, "dtype": np.dtype(op.dtype)} ) - return numba_basic.numba_njit(alloc_fn) + return numba_basic.numba_njit(alloc_fn, inline="always") @register_funcify_and_cache_key(Alloc) @@ -221,12 +221,14 @@ def makevector({", ".join(input_names)}): globals() | {"np": np, "dtype": dtype}, ) - return numba_basic.numba_njit(makevector_fn) + # Numba can't inline closures with more than 30 arguments + inline = "always" if len(input_names) <= 30 else None + return numba_basic.numba_njit(makevector_fn, inline=inline) @register_funcify_default_op_cache_key(TensorFromScalar) def numba_funcify_TensorFromScalar(op, **kwargs): - @numba_basic.numba_njit + @numba_basic.numba_njit(inline="always") def tensor_from_scalar(x): return np.array(x) @@ -235,7 +237,7 @@ def tensor_from_scalar(x): @register_funcify_default_op_cache_key(ScalarFromTensor) def numba_funcify_ScalarFromTensor(op, **kwargs): - @numba_basic.numba_njit + @numba_basic.numba_njit(inline="always") def scalar_from_tensor(x): return x.item() diff --git a/pytensor/link/numba/dispatch/vectorize_codegen.py b/pytensor/link/numba/dispatch/vectorize_codegen.py index f804c0c04c..13bd80f9d2 100644 --- a/pytensor/link/numba/dispatch/vectorize_codegen.py +++ b/pytensor/link/numba/dispatch/vectorize_codegen.py @@ -58,7 +58,9 @@ def store_core_outputs({inp_signature}, {out_signature}): "store_core_outputs", {**globals(), **global_env}, ) - return numba_basic.numba_njit(func) + # Numba can't inline closures with more than 30 arguments + inline = "always" if nin + nout <= 30 else None + return numba_basic.numba_njit(func, inline=inline) _jit_options = { diff --git a/tests/benchmarks/test_compilation.py b/tests/benchmarks/test_compilation.py index c748ba868f..3ff044681c 100644 --- a/tests/benchmarks/test_compilation.py +++ b/tests/benchmarks/test_compilation.py @@ -1,4 +1,7 @@ -from contextlib import nullcontext +import shutil +import tempfile +from contextlib import contextmanager, nullcontext +from pathlib import Path import numpy as np import pytest @@ -24,6 +27,20 @@ ) +@contextmanager +def _fresh_numba_cache(): + import pytensor.link.numba.cache as cache_mod + + original = cache_mod.NUMBA_CACHE_PATH + tmp = Path(tempfile.mkdtemp(prefix="bench_numba_cache_")) + cache_mod.NUMBA_CACHE_PATH = tmp + try: + yield tmp + finally: + cache_mod.NUMBA_CACHE_PATH = original + shutil.rmtree(tmp, ignore_errors=True) + + def create_radon_model( intercept_dist="normal", sigma_dist="halfnormal", centered=False ): @@ -183,10 +200,11 @@ def compile_and_call_once(): ) fn(x) - ctx = ( + cache_ctx = _fresh_numba_cache() if cache else nullcontext() + flag_ctx = ( config.change_flags(numba__cache=cache) if cache is not None else nullcontext() ) - with ctx: + with cache_ctx, flag_ctx: benchmark.pedantic(compile_and_call_once, rounds=5, iterations=1) @@ -201,23 +219,27 @@ def test_radon_model_compile_variants_benchmark( rng = np.random.default_rng(1) x = rng.normal(size=joined_inputs.type.shape).astype(config.floatX) - # Compile base function once to populate the cache - fn = function( - [joined_inputs], [model_logp, model_dlogp], mode=mode, trust_input=True + cache_ctx = _fresh_numba_cache() if cache else nullcontext() + flag_ctx = ( + config.change_flags(numba__cache=cache) if cache is not None else nullcontext() ) - fn(x) + with cache_ctx, flag_ctx: + # Compile base function once to populate the cache + fn = function( + [joined_inputs], [model_logp, model_dlogp], mode=mode, trust_input=True + ) + fn(x) - def compile_and_call_once(): - for joined_inputs, [model_logp, model_dlogp] in radon_model_variants: - fn = function( - [joined_inputs], [model_logp, model_dlogp], mode=mode, trust_input=True - ) - fn(x) + def compile_and_call_once(): + for joined_inputs, [model_logp, model_dlogp] in radon_model_variants: + fn = function( + [joined_inputs], + [model_logp, model_dlogp], + mode=mode, + trust_input=True, + ) + fn(x) - ctx = ( - config.change_flags(numba__cache=cache) if cache is not None else nullcontext() - ) - with ctx: benchmark.pedantic(compile_and_call_once, rounds=1, iterations=1) diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index f5286b7156..1b3e435150 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -298,10 +298,7 @@ def test_inconsistent_shared(self, shape_unsafe): if shape_unsafe: assert not has_alloc # Error raised by SpecifyShape that is introduced due to static shape inference - with pytest.raises( - AssertionError, - match="SpecifyShape: dim 0 of input has shape 3, expected 6\\.", - ): + with pytest.raises(AssertionError, match="SpecifyShape"): f() else: assert has_alloc diff --git a/tests/tensor/test_shape.py b/tests/tensor/test_shape.py index 21ec9800e2..a2712f5c56 100644 --- a/tests/tensor/test_shape.py +++ b/tests/tensor/test_shape.py @@ -496,14 +496,14 @@ def test_python_perform(self): f = pytensor.function([x, shape], y, mode=Mode("py")) assert f([1], (1,)) == [1] - with pytest.raises(AssertionError, match=r"SpecifyShape:.*"): + with pytest.raises(AssertionError, match="SpecifyShape"): assert f([1], (2,)) == [1] x = matrix() y = specify_shape(x, (None, 2)) f = pytensor.function([x], y, mode=Mode("py")) assert f(np.zeros((3, 2), dtype=config.floatX)).shape == (3, 2) - with pytest.raises(AssertionError, match=r"SpecifyShape:.*"): + with pytest.raises(AssertionError, match="SpecifyShape"): assert f(np.zeros((3, 3), dtype=config.floatX)) def test_bad_shape(self): @@ -517,7 +517,7 @@ def test_bad_shape(self): assert np.array_equal(f(xval), xval) xval = np.random.random(3).astype(config.floatX) - with pytest.raises(AssertionError, match=r"SpecifyShape:.*"): + with pytest.raises(AssertionError, match="SpecifyShape"): f(xval) assert isinstance( @@ -541,14 +541,14 @@ def test_bad_shape(self): for shape_ in [(4, 3), (2, 8)]: xval = np.random.random(shape_).astype(config.floatX) - with pytest.raises(AssertionError, match=r"SpecifyShape:.*"): + with pytest.raises(AssertionError, match="SpecifyShape"): f(xval) s = iscalar("s") f = pytensor.function([x, s], specify_shape(x, None, s), mode=self.mode) x_val = np.zeros((3, 2), dtype=config.floatX) assert f(x_val, 2).shape == (3, 2) - with pytest.raises(AssertionError, match=r"SpecifyShape:.*"): + with pytest.raises(AssertionError, match="SpecifyShape"): f(xval, 3) def test_infer_shape(self):