diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py index 4060c38365..33b2b29385 100644 --- a/pytensor/compile/builders.py +++ b/pytensor/compile/builders.py @@ -963,24 +963,66 @@ class SymbolicOp(OpFromGraph): def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) if "__props__" in cls.__dict__: - # MetaType installs props-only __hash__ and __eq__ which ignores the inner graph - # override with fgraph-aware version - cls.__hash__ = OpFromGraph.__hash__ - cls.__eq__ = OpFromGraph.__eq__ + # MetaType installs props-only __hash__/__eq__ that ignore the inner graph. + # Restore the SymbolicOp versions (fgraph-aware, and deferred-op-aware). + cls.__hash__ = SymbolicOp.__hash__ + cls.__eq__ = SymbolicOp.__eq__ + + def __hash__(self): + # A deferred SymbolicOp has no inner graph yet, so identify it by its type, + # props and static params rather than the (absent) frozen fgraph. + if getattr(self, "fgraph", None) is None: + props = tuple( + getattr(self, p) for p in getattr(type(self), "__props__", ()) + ) + return hash((type(self), props, self.static_params)) + return OpFromGraph.__hash__(self) + + def __eq__(self, other): + if self is other: + return True + if type(self) is not type(other): + return False + self_built = getattr(self, "fgraph", None) is not None + other_built = getattr(other, "fgraph", None) is not None + if self_built and other_built: + return OpFromGraph.__eq__(self, other) + if self_built != other_built: + return False + props = getattr(type(self), "__props__", ()) + return self.static_params == other.static_params and all( + getattr(self, p) == getattr(other, p) for p in props + ) @staticmethod def filter_inputs(*inputs): return inputs + def build_static_params(self, inputs): + """Hashable static information extracted from the actual input *values*. + + Some inner graphs depend on input information that is not captured by the + input *types* — most notably the concrete dimensions encoded by a ``size`` + vector, which determine the static (broadcastable) shape of the outputs. + Subclasses may override this to return such information from the actual + ``inputs``. The returned value is stored as ``self.static_params`` (so it is + available to :meth:`build_inner_graph`) and participates in the decision of + whether the Op must be rebuilt for a new set of inputs. + + The default returns ``None`` (the inner graph depends only on input types). + """ + return None + def build_inner_graph(self, *inputs) -> list[Variable]: raise NotImplementedError - def __init__(self, input_types=None, **kwargs): + def __init__(self, input_types=None, static_params=None, **kwargs): """Construct op for the given input Types. When input_types is None, construction is deferred until the first __call__, which inspects the actual input types and builds the graph. """ + self.static_params = static_params for prop in getattr(type(self), "__props__", ()): if prop in kwargs: setattr(self, prop, kwargs.pop(prop)) @@ -992,15 +1034,28 @@ def __init__(self, input_types=None, **kwargs): outputs = self.build_inner_graph(*dummy_inputs) super().__init__(dummy_inputs, outputs, **kwargs) - def __call__(self, *inputs, **kwargs): - inputs = self.filter_inputs(*inputs) - input_types = tuple(inp.type for inp in inputs) - - if hasattr(self, "fgraph") and input_types == tuple(self.input_types): - return super().__call__(*inputs, **kwargs) + def _resolve_op(self, inputs) -> SymbolicOp: + """Return the concrete (built) Op matching the given inputs. + Reuses ``self`` when its inner graph already matches the inputs' types and + static params; otherwise builds a new Op for them. + """ + input_types = tuple(inp.type for inp in inputs) + static_params = self.build_static_params(inputs) + if ( + hasattr(self, "fgraph") + and input_types == tuple(self.input_types) + and static_params == self.static_params + ): + return self init_kwargs = dict(self._init_kwargs) for prop in getattr(type(self), "__props__", ()): init_kwargs[prop] = getattr(self, prop) - op = type(self)(input_types=list(input_types), **init_kwargs) + return type(self)( + input_types=list(input_types), static_params=static_params, **init_kwargs + ) + + def __call__(self, *inputs, **kwargs): + inputs = self.filter_inputs(*inputs) + op = self._resolve_op(inputs) return super(SymbolicOp, op).__call__(*inputs, **kwargs) diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 23a82ab270..ce340361a2 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -893,36 +893,51 @@ def numba_funcify_Dot(op, node, **kwargs): if x_dtype == numba_dot_dtype and y_dtype == numba_dot_dtype: @numba_basic.numba_njit - def dot(x, y): - return np.asarray(np.dot(x, y)) + def dot(x, y, out=None): + if out is None: + return np.asarray(np.dot(x, y)) + np.dot(x, y, out) + return out elif x_dtype == numba_dot_dtype and y_dtype != numba_dot_dtype: @numba_basic.numba_njit - def dot(x, y): - return np.asarray(np.dot(x, y.astype(numba_dot_dtype))) + def dot(x, y, out=None): + if out is None: + return np.asarray(np.dot(x, y.astype(numba_dot_dtype))) + np.dot(x, y.astype(numba_dot_dtype), out) + return out elif x_dtype != numba_dot_dtype and y_dtype == numba_dot_dtype: @numba_basic.numba_njit - def dot(x, y): - return np.asarray(np.dot(x.astype(numba_dot_dtype), y)) + def dot(x, y, out=None): + if out is None: + return np.asarray(np.dot(x.astype(numba_dot_dtype), y)) + np.dot(x.astype(numba_dot_dtype), y, out) + return out else: @numba_basic.numba_njit - def dot(x, y): - return np.asarray( - np.dot(x.astype(numba_dot_dtype), y.astype(numba_dot_dtype)) - ) + def dot(x, y, out=None): + if out is None: + return np.asarray( + np.dot(x.astype(numba_dot_dtype), y.astype(numba_dot_dtype)) + ) + np.dot(x.astype(numba_dot_dtype), y.astype(numba_dot_dtype), out) + return out - cache_version = 1 + cache_version = 2 if out_dtype == numba_dot_dtype: + # np.dot can write straight into the pre-allocated batch output slice. + dot.handles_out = True return dot, cache_version else: - + # Output needs a dtype cast np.dot can't do in place, so fall back to + # the copying store_core_outputs wrapper. @numba_basic.numba_njit def dot_with_cast(x, y): return dot(x, y).astype(out_dtype) @@ -935,14 +950,16 @@ def numba_funcify_BatchedDot(op, node, **kwargs): dtype = node.outputs[0].type.numpy_dtype @numba_basic.numba_njit - def batched_dot(x, y): + def batched_dot(x, y, out=None): # Numba does not support 3D matmul # https://github.com/numba/numba/issues/3804 - shape = x.shape[:-1] + y.shape[2:] - z0 = np.empty(shape, dtype=dtype) - for i in range(z0.shape[0]): - z0[i] = np.dot(x[i], y[i]) + if out is None: + shape = x.shape[:-1] + y.shape[2:] + out = np.empty(shape, dtype=dtype) + for i in range(out.shape[0]): + out[i] = np.dot(x[i], y[i]) - return z0 + return out - return batched_dot + batched_dot.handles_out = True + return batched_dot, 1 diff --git a/pytensor/link/numba/dispatch/linalg/constructors.py b/pytensor/link/numba/dispatch/linalg/constructors.py index 72ca52964a..892099207b 100644 --- a/pytensor/link/numba/dispatch/linalg/constructors.py +++ b/pytensor/link/numba/dispatch/linalg/constructors.py @@ -18,10 +18,13 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs): The generated code looks something like: - def block_diagonal(arr0, arr1, arr2): + def block_diagonal(arr0, arr1, arr2, out=None): out_r = arr0.shape[0] + arr1.shape[0] + arr2.shape[0] out_c = arr0.shape[1] + arr1.shape[1] + arr2.shape[1] - out = np.zeros((out_r, out_c), dtype=np.float64) + if out is None: + out = np.zeros((out_r, out_c), dtype=np.float64) + else: + out[:] = 0 r, c = 0, 0 rr, cc = arr0.shape @@ -46,11 +49,18 @@ def block_diagonal(arr0, arr1, arr2): arg_names = [f"arr{i}" for i in range(n_inp)] code = [ - f"def block_diagonal({', '.join(arg_names)}):", + f"def block_diagonal({', '.join(arg_names)}, out=None):", CODE_TOKEN.INDENT, f"out_r = {' + '.join(f'{a}.shape[0]' for a in arg_names)}", f"out_c = {' + '.join(f'{a}.shape[1]' for a in arg_names)}", + "if out is None:", + CODE_TOKEN.INDENT, f"out = np.zeros((out_r, out_c), dtype=np.{dtype})", + CODE_TOKEN.DEDENT, + "else:", + CODE_TOKEN.INDENT, + "out[:] = 0", + CODE_TOKEN.DEDENT, CODE_TOKEN.EMPTY_LINE, "r, c = 0, 0", ] @@ -73,5 +83,7 @@ def block_diagonal(arr0, arr1, arr2): globals() | {"np": np}, ) - cache_version = 1 - return numba_basic.numba_njit(block_diag), cache_version + block_diag = numba_basic.numba_njit(block_diag) + block_diag.handles_out = True + cache_version = 2 + return block_diag, cache_version diff --git a/pytensor/link/numba/dispatch/random.py b/pytensor/link/numba/dispatch/random.py index a73b9c783e..42e06cfd71 100644 --- a/pytensor/link/numba/dispatch/random.py +++ b/pytensor/link/numba/dispatch/random.py @@ -173,22 +173,25 @@ def core_MultinomialRV(op, node): dtype = op.dtype @numba_basic.numba_njit - def random_fn(rng, n, p): + def random_fn(rng, n, p, out=None): n_cat = p.shape[0] - draws = np.zeros(n_cat, dtype=dtype) + if out is None: + out = np.empty(n_cat, dtype=dtype) + out[:] = 0 remaining_p = np.float64(1.0) remaining_n = n for i in range(n_cat - 1): - draws[i] = rng.binomial(remaining_n, p[i] / remaining_p) - remaining_n -= draws[i] + out[i] = rng.binomial(remaining_n, p[i] / remaining_p) + remaining_n -= out[i] if remaining_n <= 0: break remaining_p -= p[i] if remaining_n > 0: - draws[n_cat - 1] = remaining_n - return draws + out[n_cat - 1] = remaining_n + return out - return random_fn + random_fn.handles_out = True + return random_fn, 1 @numba_core_rv_funcify.register(ptr.MvNormalRV) @@ -220,13 +223,16 @@ def core_DirichletRV(op, node): dtype = op.dtype @numba_basic.numba_njit - def random_fn(rng, alpha): - y = np.empty_like(alpha, dtype=dtype) + def random_fn(rng, alpha, out=None): + if out is None: + out = np.empty_like(alpha, dtype=dtype) for i in range(len(alpha)): - y[i] = rng.gamma(alpha[i], 1.0) - return y / y.sum() + out[i] = rng.gamma(alpha[i], 1.0) + out /= out.sum() + return out - return random_fn, 1 + random_fn.handles_out = True + return random_fn, 2 @numba_core_rv_funcify.register(ptr.GumbelRV) diff --git a/pytensor/tensor/random/basic.py b/pytensor/tensor/random/basic.py index 198075d3ea..f87cb012c1 100644 --- a/pytensor/tensor/random/basic.py +++ b/pytensor/tensor/random/basic.py @@ -3,17 +3,11 @@ from typing import Literal import numpy as np -from numpy import broadcast_shapes as np_broadcast_shapes -from numpy import einsum as np_einsum -from numpy import sqrt as np_sqrt -from numpy.linalg import cholesky as np_cholesky -from numpy.linalg import eigh as np_eigh -from numpy.linalg import svd as np_svd from pytensor.tensor import get_vector_length, specify_shape from pytensor.tensor.basic import as_tensor_variable from pytensor.tensor.math import sqrt -from pytensor.tensor.random.op import RandomVariable +from pytensor.tensor.random.op import RandomVariable, SymbolicRVOp from pytensor.tensor.random.utils import ( broadcast_params, normalize_size_param, @@ -852,7 +846,7 @@ def __call__(self, mu, kappa, size=None, **kwargs): vonmises = VonMisesRV() -class MvNormalRV(RandomVariable): +class MvNormalRV(SymbolicRVOp): r"""A multivariate normal random variable. The probability density function for `multivariate_normal` in term of its location parameter @@ -867,20 +861,20 @@ class MvNormalRV(RandomVariable): """ name = "multivariate_normal" - signature = "(n),(n,n)->(n)" - dtype = "floatX" + extended_signature = "[rng],[size],(n),(n,n)->[rng],(n)" _print_name = ("MultivariateNormal", "\\operatorname{MultivariateNormal}") - __props__ = ("name", "signature", "dtype", "inplace", "method") + __props__ = ("method",) - def __init__(self, *args, method: Literal["cholesky", "svd", "eigh"], **kwargs): - super().__init__(*args, **kwargs) + def __init__( + self, *, method: Literal["cholesky", "svd", "eigh"] = "cholesky", **kwargs + ): if method not in ("cholesky", "svd", "eigh"): raise ValueError( f"Unknown method {method}. The method must be one of 'cholesky', 'svd', or 'eigh'." ) - self.method = method + super().__init__(method=method, **kwargs) - def __call__(self, mean, cov, size=None, method=None, **kwargs): + def __call__(self, mean, cov, size=None, method=None, rng=None, **kwargs): r""" "Draw samples from a multivariate normal distribution. Signature @@ -904,36 +898,47 @@ def __call__(self, mean, cov, size=None, method=None, **kwargs): """ if method is not None and method != self.method: - # Recreate Op with the new method - props = self._props_dict() - props["method"] = method - new_op = type(self)(**props) - return new_op.__call__(mean, cov, size=size, method=method, **kwargs) - return super().__call__(mean, cov, size=size, **kwargs) - - def rng_fn(self, rng, mean, cov, size): - if size is None: - size = np_broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + return type(self)(method=method)(mean, cov, size=size, rng=rng, **kwargs) + return super().__call__(mean, cov, size=size, rng=rng, **kwargs) + + def build_inner_graph(self, rng, size, mean, cov): + from pytensor.tensor.extra_ops import broadcast_shape + from pytensor.tensor.linalg import cholesky, eigh, svd + from pytensor.tensor.math import matvec, sqrt + from pytensor.tensor.type_other import NoneTypeT if self.method == "cholesky": - A = np_cholesky(cov) + a = cholesky(cov, lower=True) elif self.method == "svd": - A, s, _ = np_svd(cov) - A *= np_sqrt(s, out=s)[..., None, :] + u, s, _ = svd(cov) + a = u * sqrt(s)[..., None, :] + else: + w, v = eigh(cov) + a = v * sqrt(w)[..., None, :] + + core_shape = (cov.shape[-1],) + if isinstance(size.type, NoneTypeT): + batch_shape = broadcast_shape( + tuple(mean.shape)[:-1], + tuple(cov.shape)[:-2], + arrays_are_shapes=True, + ) else: - w, A = np_eigh(cov) - A *= np_sqrt(w, out=w)[..., None, :] - - out = rng.normal(size=(*size, mean.shape[-1])) - np_einsum( - "...ij,...j->...i", # numpy doesn't have a batch matrix-vector product - A, - out, - optimize=False, # Nothing to optimize with two operands, skip costly setup - out=out, + # Use the statically-known size dimensions (self.static_params) so the + # draws carry the correct static/broadcastable shape; fall back to the + # runtime size vector for dimensions that aren't statically known. + batch_shape = tuple( + runtime_dim if static_dim is None else static_dim + for static_dim, runtime_dim in zip( + self.static_params, size, strict=True + ) + ) + + next_rng, std_draws = normal( + 0.0, 1.0, size=(*batch_shape, *core_shape), rng=rng, return_next_rng=True ) - out += mean - return out + draws = mean + matvec(a, std_draws) + return [next_rng, draws] multivariate_normal = MvNormalRV(method="cholesky") diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index 34479e142c..ff823b8bed 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -1,4 +1,5 @@ import abc +import re import warnings from collections.abc import Sequence from typing import Any, cast @@ -6,6 +7,7 @@ import numpy as np import pytensor +from pytensor.compile.builders import SymbolicOp from pytensor.configdefaults import config from pytensor.graph.basic import Apply, Variable, equal_computations from pytensor.graph.op import Op @@ -481,6 +483,166 @@ def pushforward(self, inputs, outputs, eval_points): return [disconnected_type() for i in eval_points] +class SymbolicRVOp(SymbolicOp, RNGConsumerOp): + """A random variable whose draws are defined by a symbolic inner graph. + + Unlike `RandomVariable`, which carries an opaque ``rng_fn``, a + `SymbolicRVOp` expresses its draws as a graph of regular + `Op`\\s (e.g. another `RandomVariable` plus deterministic transforms). + Because the inner graph is inlined before backend compilation, every + backend gets the decomposed graph for free, and gradients follow through + the reparameterization. + + Subclasses implement :meth:`build_inner_graph` returning + ``[next_rng, draws]`` and set :attr:`extended_signature`, a gufunc-like + signature that uses the special tokens ``[rng]`` and ``[size]`` to mark the + rng and size inputs, e.g. ``"[rng],[size],(n),(n,n)->[rng],(n)"``. + """ + + extended_signature: str | None = None + inline = True + + @staticmethod + def filter_inputs(rng, size, *params): + return ( + rng, + normalize_size_param(size), + *(as_tensor_variable(p) for p in params), + ) + + @property + def signature(self) -> str | None: + """The signature with the special ``[rng]``/``[size]`` tokens removed.""" + extended_signature = self.extended_signature + if extended_signature is None: + return None + special_tokens = r"|".join((r"\[rng\],?", r"\[size\],?")) + signature = re.sub(special_tokens, "", extended_signature) + # Remove dangling commas + return re.sub(r",(?=[->])|,$", "", signature) + + @property + def ndims_params(self) -> Sequence[int] | None: + signature = self.signature + if signature is None: + return None + inputs_sig, _ = _parse_gufunc_signature(signature) + return [len(sig) for sig in inputs_sig] + + @property + def ndim_supp(self) -> int | None: + signature = self.signature + if signature is None: + return None + _, outputs_sig = _parse_gufunc_signature(signature) + return max(len(out_sig) for out_sig in outputs_sig) + + @staticmethod + def get_input_output_type_idxs(extended_signature): + """Parse ``extended_signature`` into rng/size/param and rng/output indices.""" + if extended_signature is None: + raise ValueError("extended_signature must be provided") + fake_signature = extended_signature.replace("[rng]", "(rng)").replace( + "[size]", "(size)" + ) + inputs_sig, outputs_sig = _parse_gufunc_signature(fake_signature) + + rng_in_idxs, size_idx, param_idxs = [], None, [] + for i, inp_sig in enumerate(inputs_sig): + if inp_sig == ("size",): + size_idx = i + elif inp_sig == ("rng",): + rng_in_idxs.append(i) + else: + param_idxs.append(i) + + rng_out_idxs, out_idxs = [], [] + for i, out_sig in enumerate(outputs_sig): + if out_sig == ("rng",): + rng_out_idxs.append(i) + else: + out_idxs.append(i) + + return ( + (tuple(rng_in_idxs), size_idx, tuple(param_idxs)), + (tuple(rng_out_idxs), tuple(out_idxs)), + ) + + def rng_params(self, node) -> tuple[Variable, ...]: + (rng_in_idxs, _, _), _ = self.get_input_output_type_idxs( + self.extended_signature + ) + return tuple(node.inputs[i] for i in rng_in_idxs) + + def size_param(self, node) -> Variable | None: + (_, size_idx, _), _ = self.get_input_output_type_idxs(self.extended_signature) + return node.inputs[size_idx] if size_idx is not None else None + + def dist_params(self, node) -> tuple[Variable, ...]: + (_, _, param_idxs), _ = self.get_input_output_type_idxs(self.extended_signature) + return tuple(node.inputs[i] for i in param_idxs) + + def build_static_params(self, inputs): + """Return the statically-known dimensions of the ``size`` vector. + + ``size`` is the one input whose *value* (not type) determines the + outputs' static batch shape, so we fold it into the build. Returns a tuple + of ``int | None`` (``None`` per dimension that is not statically known), or + ``None`` when there is no size input or ``size`` itself is ``None``. + """ + (_, size_idx, _), _ = self.get_input_output_type_idxs(self.extended_signature) + if size_idx is None: + return None + size = inputs[size_idx] + if isinstance(size.type, NoneTypeT): + return None + size_len = get_vector_length(size) + _, static_shape = infer_static_shape([size[i] for i in range(size_len)]) + return tuple(static_shape) + + def build_inner_graph(self, rng, size, *dist_params) -> list[Variable]: + raise NotImplementedError + + def __init__(self, input_types=None, **kwargs): + # The size (and sometimes rng) input may be unused by the inner graph + kwargs.setdefault("on_unused_input", "ignore") + super().__init__(input_types=input_types, **kwargs) + if input_types is not None: + (_, out_idxs) = self.get_input_output_type_idxs(self.extended_signature)[1] + # Return the (single) non-rng output by default + if len(out_idxs) == 1: + self.default_output = out_idxs[0] + + def __call__( + self, *dist_params, size=None, rng=None, return_next_rng=False, **kwargs + ): + if rng is None: + from pytensor.tensor.random.variable import shared_rng + + rng = shared_rng(seed=None) + draws = super().__call__(rng, size, *dist_params, **kwargs) + if not return_next_rng: + return draws + (_, _, _), (rng_out_idxs, _) = self.get_input_output_type_idxs( + self.extended_signature + ) + [rng_out_idx] = rng_out_idxs + return draws.owner.outputs[rng_out_idx], draws + + def update(self, node: Apply) -> dict[Variable, Variable]: + (rng_in_idxs, _, _), (rng_out_idxs, _) = self.get_input_output_type_idxs( + self.extended_signature + ) + return { + node.inputs[i]: node.outputs[o] + for i, o in zip(rng_in_idxs, rng_out_idxs, strict=True) + } + + def batch_ndim(self, node: Apply) -> int: + out_ndim = max(getattr(out.type, "ndim", 0) for out in node.outputs) + return cast(int, out_ndim - self.ndim_supp) + + class AbstractRNGConstructor(Op): def make_node(self, seed=None): if seed is None: diff --git a/pytensor/xtensor/rewriting/vectorization.py b/pytensor/xtensor/rewriting/vectorization.py index 824e675f1d..3cdb95d364 100644 --- a/pytensor/xtensor/rewriting/vectorization.py +++ b/pytensor/xtensor/rewriting/vectorization.py @@ -100,10 +100,11 @@ def lower_rv(fgraph, node): input_is_xrng = isinstance(rng.type, XRandomGeneratorType) tensor_rng = as_rng(rng) - # RVs are their own core Op - new_next_rng, tensor_out = core_op.make_node( - tensor_rng, size, *tensor_params - ).outputs + # RVs are their own core Op. Build via __call__ so SymbolicRVOp core ops + # (e.g. MvNormalRV) construct their inner graph lazily for these input types. + new_next_rng, tensor_out = core_op( + *tensor_params, size=size, rng=tensor_rng, return_next_rng=True + ) # Cast back to xtensor RNG if the input was xtensor if input_is_xrng: diff --git a/pytensor/xtensor/vectorization.py b/pytensor/xtensor/vectorization.py index fa4b058563..9f8dad67f8 100644 --- a/pytensor/xtensor/vectorization.py +++ b/pytensor/xtensor/vectorization.py @@ -315,7 +315,11 @@ def make_node(self, rng, *extra_dim_lengths_and_params): from pytensor.tensor.random.type import random_generator_type dummy_rng = random_generator_type() - core_node = self.core_op.make_node(dummy_rng, None, *dummy_core_inputs) + # Build a dummy core node via __call__ (works for both RandomVariable and + # SymbolicRVOp core ops; the latter builds its inner graph lazily). + core_node = self.core_op( + *dummy_core_inputs, rng=dummy_rng, return_next_rng=True + )[1].owner if not len(core_node.outputs) == 2: raise NotImplementedError( diff --git a/tests/benchmarks/test_random.py b/tests/benchmarks/test_random.py new file mode 100644 index 0000000000..c3eca0be59 --- /dev/null +++ b/tests/benchmarks/test_random.py @@ -0,0 +1,29 @@ +import numpy as np +import pytest + +import pytensor.tensor as pt +from pytensor import function, shared + + +@pytest.mark.parametrize("n, size", [(50, 20_000)]) +def test_mvnormal_shared_cov_benchmark_numba(n, size, benchmark): + """MvNormal draws from a single shared covariance. + + Because ``MvNormalRV`` is a ``SymbolicRVOp`` (``mean + cholesky(cov) @ z``), + an unbatched ``cov`` is factorized once and reused across all ``size`` draws, + instead of re-factorizing it per draw inside the vectorized loop. + """ + rng = shared(np.random.default_rng(0)) + mean = pt.zeros(n) + cov = pt.tensor("cov", shape=(n, n)) + draws = pt.random.multivariate_normal(mean, cov, size=(size,), rng=rng) + + fn = function([cov], draws, mode="NUMBA", trust_input=True) + + test_rng = np.random.default_rng(1) + a = test_rng.standard_normal((n, n)) + cov_test = a @ a.T + n * np.eye(n) # symmetric positive-definite + + out = fn(cov_test) + assert out.shape == (size, n) + benchmark(fn, cov_test) diff --git a/tests/link/numba/test_random.py b/tests/link/numba/test_random.py index 6dd5cdbbe7..9a0cda3513 100644 --- a/tests/link/numba/test_random.py +++ b/tests/link/numba/test_random.py @@ -150,6 +150,31 @@ def test_multivariate_normal(): ) +@pytest.mark.parametrize("cov_dtype", ["float64", "float32", "int16"]) +@pytest.mark.parametrize("mu_dtype", ["float64", "float32", "int16"]) +def test_mvnormal_mixed_dtype_inputs(mu_dtype, cov_dtype): + """mean and cov dtypes are not enforced and can be integer or any float precision. + + The numba MvNormal must handle every combination (it previously crashed for any + non-float64 input: mixed-dtype ``np.dot`` and ``np.linalg.cholesky`` on integers). + """ + mean = np.array([1, 2, 3], dtype=mu_dtype) + cov = np.array( + [ + [4, 1, 0], + [1, 4, 0], + [0, 0, 4], + ], + dtype=cov_dtype, + ) + rng = shared(np.random.default_rng(675)) + draws = pt.random.multivariate_normal(mean, cov, size=(10_000,), rng=rng) + + draws_eval = draws.eval(mode="NUMBA") + np.testing.assert_allclose(np.mean(draws_eval, axis=0), mean, atol=0.1) + np.testing.assert_allclose(np.cov(draws_eval, rowvar=False), cov, atol=0.2) + + @pytest.mark.parametrize( "rv_op, dist_args, size", [ diff --git a/tests/tensor/random/test_basic.py b/tests/tensor/random/test_basic.py index 358c95fc66..a4d42bfa99 100644 --- a/tests/tensor/random/test_basic.py +++ b/tests/tensor/random/test_basic.py @@ -637,13 +637,12 @@ def test_mvnormal_no_default_args(): def test_mvnormal_impl_catches_incompatible_size(): - with pytest.raises(ValueError, match="operands could not be broadcast together "): - multivariate_normal.rng_fn( - np.random.default_rng(), + with pytest.raises((ValueError, AssertionError)): + multivariate_normal( np.zeros((3, 2)), np.broadcast_to(np.eye(2), (3, 2, 2)), size=(4,), - ) + ).eval() def test_mvnormal_ShapeFeature(): @@ -711,14 +710,11 @@ def test_mvnormal_cov_decomposition_method(method, psd): draws = multivariate_normal(mean, cov, method=method, size=(10_000,), rng=rng) assert draws.owner.op.method == method - # JAX doesn't raise errors at runtime if not psd and method == "cholesky": - if mode == "JAX": - # JAX doesn't raise errors at runtime, instead it returns nan - np.isnan(draws.eval(mode=mode)).all() - else: - with pytest.raises(np.linalg.LinAlgError): - draws.eval(mode=mode) + # The decomposed MvNormal uses the Cholesky Op, which returns nan + # (rather than raising) for non-positive-definite inputs across all + # backends, so the draws propagate nan. + assert np.isnan(draws.eval(mode=mode)).all() else: draws_eval = draws.eval(mode=mode)