Ensure distribution moments/samples are jax arrays#2206

Open
tillahoffmann wants to merge 7 commits into pyro-ppl:master from tillahoffmann:add-jax-array-output-test

@tillahoffmann tillahoffmann commented Jun 8, 2026

Collaborator

Summary

This PR makes numpyro distributions return concrete jax.Array from their public methods (mean, variance, entropy, sample, log_prob, cdf, icdf, …) and annotates them accordingly, so that typed downstream code can actually use distribution outputs. To make that true at runtime, parameters are coerced to jax.Array at construction.

The change aligns numpyro with JAX's own typing best practices (wide on input with ArrayLike, strict on output with Array) and removes a class of type errors that typed consumers will hit once they rely on numpyro's types.

The near-term, standalone win is the parameter coercion: today whether a stored parameter (self.loc, self.rate, …) is a jax.Array, a numpy array, or a python scalar depends on the input — this PR makes it consistently a jax.Array. The output annotations are the cheap, correct-by-construction follow-through; they become observable to consumers once numpyro ships py.typed (a deliberate follow-up — see below).

This is necessarily runtime work, not just an annotation change: many moments return a stored parameter directly (Poisson.mean is self.rate), so an Array return type is only honest if that parameter is actually an array — typing the outputs without making the runtime match would just lie to the type checker.

It also changes one behavior deliberately: argument validation (validate_args=True) is now an eager-only operation. That is the part most worth discussing, so it is called out explicitly below.

Motivation: the output annotations are wrong, and it bites consumers

Public methods were annotated to return jax.typing.ArrayLike:

ArrayLike = Union[jax.Array, np.ndarray, np.bool, np.number, bool, int, float, complex]

ArrayLike is a deliberately wide input union. JAX's typing guidance is explicit that it is the right type for function inputs (be liberal in what you accept) and that outputs should be annotated Array (be strict in what you return). numpyro applied ArrayLike to outputs too, which is incorrect: mean/sample/log_prob/… (should) always produce a concrete array.

The cost lands on every typed consumer. The scalar members of the union (bool/int/float/complex) support none of indexing, .shape, .reshape, or assignment to an Array parameter, so ordinary usage fails type checking:

import jax.numpy as jnp
import numpyro.distributions as dist

# some_jit_fn stands for any function annotated `(x: Array)`.
d = dist.Normal(jnp.zeros(5), 1.0)
d.mean[3]                  # error: int/float/complex are not indexable
d.variance.shape           # error: .shape not defined on int/float/complex
some_jit_fn(d.mean)        # error: ArrayLike not assignable to Array
d.mean > 1                 # error: ">" not supported for "complex" and "Literal[1]"

Under pyright, every one of those lines is a type error when the method returns ArrayLike (each indexing/attribute access fails once per scalar member of the union — int, float, complex, …), and all of them type-check cleanly once it returns Array.
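The "wide on input, strict on output" convention can be illustrated with a minimal sketch. The function `standardize` is hypothetical and not part of numpyro; it only assumes `jax.typing.ArrayLike`, `jax.Array`, and `jnp.asarray`.

```python
import jax.numpy as jnp
import numpy as np
from jax import Array
from jax.typing import ArrayLike


def standardize(x: ArrayLike, loc: ArrayLike = 0.0, scale: ArrayLike = 1.0) -> Array:
    """Hypothetical helper: accepts any ArrayLike, always returns a jax.Array."""
    # jnp.asarray turns python scalars and numpy arrays into jax arrays.
    x, loc, scale = jnp.asarray(x), jnp.asarray(loc), jnp.asarray(scale)
    return (x - loc) / scale


# Callers may pass numpy arrays and python scalars; the result is indexable
# and has a .shape because it is typed (and is) a jax.Array.
z = standardize(np.arange(3.0), loc=1, scale=2.0)
```

With the `-> Array` return type, `z[0]` and `z.shape` type-check without the caller narrowing a union first.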

This is currently latent (numpyro ships no py.typed, so checkers treat it as untyped Any), but it is a trap: the moment a consumer relies on numpyro's types — or numpyro starts enforcing its own — these annotations break real code.

The benefit here is for downstream typed code. This PR does not aim to make numpyro's own modules type-clean: they are not mypy-enforced, and their residual errors are dominated by pre-existing [override] signature mismatches unrelated to these annotations (enabling enforcement is called out as out of scope below).

What this PR does

  1. Coerce parameters to jax.Array at construction. Previously parameter storage was inconsistent: most parameters flow through promote_shapes, which preserved input types and only produced arrays when a broadcast forced a reshape, while some distributions stored the raw argument directly — so whether self.loc ended up an array depended on the input shapes and dtype. Now promote_shapes takes an opt-in promote_array=True (typed to return list[Array]) used at the parameter-assignment sites, and the few directly-assigned parameters are wrapped in jnp.asarray. As a result self.loc, self.scale, … are statically and dynamically jax.Array.

This is runtime work, not just annotation, on purpose: annotating the outputs as Array without actually returning arrays would be a lie to the type checker — many moments simply return a stored parameter (Poisson.mean is self.rate), so the annotation is only honest if the parameter is genuinely an array. Typing that isn't backed by the runtime is worse than no typing.

  2. Annotate outputs as Array. Return-position ArrayLike becomes Array across the distribution modules. Constructor parameters and method arguments (value, q) keep ArrayLike, so callers can still pass python scalars / numpy arrays.

  3. Make instance types reachable through the metaclass. DistributionMeta.__call__ is annotated to return the concrete class being constructed, so dist.Normal(...) is typed Normal (not Any, and not the base Distribution) and the -> Array annotations on its methods are visible to consumers.

  4. Use functional array ops on inputs where appropriate. Where a method only needed a shape it now uses jnp.shape(value) rather than value.shape, which is correct for any ArrayLike (and incidentally more robust). A few methods that genuinely index value coerce it once.
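The metaclass annotation can be sketched in isolation. This is a toy illustration under stated assumptions, not numpyro's actual `DistributionMeta` (which does more work in `__call__`); the class bodies here are hypothetical stand-ins.

```python
from typing import Any, TypeVar

T = TypeVar("T")


class DistributionMeta(type):
    # Annotating `cls` as type[T] lets a checker infer that calling a class
    # yields an instance of that same class rather than Any or the base class.
    def __call__(cls: type[T], *args: Any, **kwargs: Any) -> T:  # type: ignore[misc]
        return super().__call__(*args, **kwargs)  # type: ignore[misc]


class Distribution(metaclass=DistributionMeta):
    pass


class Normal(Distribution):
    def __init__(self, loc: float, scale: float) -> None:
        self.loc = loc
        self.scale = scale


# A checker that accepts the self-type infers `d: Normal` here.
d = Normal(0.0, 1.0)
```

The `type: ignore` comments are only there for mypy, which the commit notes say rejects this pattern; they have no effect at runtime.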

Parameters that must remain non-array are preserved: pytree aux fields (total_count on the multinomial family, adj_matrix on CAR) are static metadata — coercing them to a jax.Array would make them tracers under jit/vmap and break sampling — so they are left un-coerced, as is a scipy-sparse adj_matrix (CAR(is_sparse=True)).
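A minimal sketch of the split between coerced data parameters and un-coerced aux fields, assuming nothing beyond `jnp.asarray` and an `isinstance` fast path. `as_jax_array` and `ToyBinomial` are hypothetical names for illustration; the PR itself routes coercion through `promote_shapes(..., promote_array=True)`.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import Array
from jax.typing import ArrayLike


def as_jax_array(x: ArrayLike) -> Array:
    # Fast path: skip the jnp.asarray dispatch when x is already a jax array.
    return x if isinstance(x, jax.Array) else jnp.asarray(x)


class ToyBinomial:
    """Hypothetical stand-in for a distribution with one aux field."""

    def __init__(self, probs: ArrayLike, total_count: int = 1) -> None:
        self.probs = as_jax_array(probs)  # data parameter: coerced
        self.total_count = total_count    # aux field: left as a python int

    @property
    def mean(self) -> Array:
        # A python int times a jax array is a jax array.
        return self.total_count * self.probs


d = ToyBinomial(np.array([0.2, 0.5]), total_count=10)
already = jnp.ones(2)
```

Here `d.probs` and `d.mean` are jax arrays regardless of the input type, while `d.total_count` stays a plain `int`.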

The coercion touches a parameter-assignment site in essentially every distribution, so correctness rests on tests that run across all distributions rather than on per-site review:

  • test_params_and_outputs_are_arrays — samples, moments, and stored (non-aux) parameters are jax.Array. This catches any data parameter the sweep missed.
  • test_aux_fields_are_not_jax_arrays — no pytree aux field is ever a jax.Array. This catches the opposite mistake: an aux field (e.g. total_count) wrongly coerced, which would break jit/vmap (for instance, a jitted Multinomial(...).sample(...) needs a concrete total_count).
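Why a traced aux field breaks shape-dependent code can be shown with plain JAX. This is a toy analogy, not numpyro code: `support` is a hypothetical function whose output shape depends on a count, much like a support range built with `jnp.arange`.

```python
import jax
import jax.numpy as jnp


def support(total_count):
    # jnp.arange needs a concrete stop value to fix the output shape.
    return jnp.arange(total_count + 1)


# Marking the count as static keeps it a python int inside jit.
support_static = jax.jit(support, static_argnums=0)
values = support_static(3)

# Passing it as a regular (traced) argument fails at trace time.
try:
    jax.jit(support)(3)
    traced_ok = True
except TypeError:  # jax's ConcretizationTypeError subclasses TypeError
    traced_ok = False
```

Keeping aux fields out of the coerced set plays the same role as `static_argnums` in this sketch: the value stays concrete while the data parameters are traced.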

The deliberate behavior change: validation is eager-only

Because parameters are coerced to arrays before Distribution.__init__ runs validate_args, an out-of-bounds constant parameter inside jit no longer raises. Previously (#775) it did, by keeping parameters as concrete (non-jax) values so the validity check could constant-fold. #775 added this so that a typo'd literal — Normal(0., -1.) in a jitted model — would fail loudly at trace time rather than silently produce NaNs. That is a real concern, so this is worth making the case for rather than glossing over.

The guarantee it provided was always narrow. It fired only for parameters that were compile-time constants; the moment the same bad value arrives through a traced argument (Normal(0., scale) with scale = -1.), validation cannot run — you can't branch on a tracer — and never has. So on master today, with the same validate_args=True and the same invalid value, whether you get an error depends on how the value was staged:

import jax
import numpyro.distributions as dist

# constant invalid scale -> raises
jax.jit(lambda: dist.Normal(0.0, -1.0, validate_args=True).log_prob(0.5))()

# traced invalid scale -> silently does NOT raise (same bad value!)
jax.jit(lambda s: dist.Normal(0.0, s, validate_args=True).log_prob(0.5))(-1.0)

A safety net that catches -1.0 written as a literal but not -1.0 passed as an argument is not one a user can rely on without knowing how their parameters get staged.

The new behavior is simpler to state: validate_args=True validates eagerly and does nothing under jit. Eager validation is unchanged — invalid parameters raise exactly as before outside jit, which is where it is most useful for debugging (construct the distribution eagerly, or run once without jit, to surface the error).

This also lines up with how out-of-support sample validation already behaves. That check warns only when the support mask is concrete, so it too goes silent under jit whenever the sampled value is traced — which is the usual case, since in a model/MCMC/SVI the value fed to log_prob is the (traced) data. The only situation where it still fires under jit is a constant-support distribution with a constant value, the same accidental-constant-folding corner that parameter validation used to rely on. So under realistic jitted usage both parameter and sample validation are effectively eager-only.
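The eager-only semantics can be sketched with a standalone validator. `check_positive` and `log_scale` are hypothetical; the sketch assumes only that a comparison on a jax array inside `jit` yields a tracer, which has no concrete truth value.

```python
import jax
import jax.numpy as jnp


def check_positive(scale):
    """Hypothetical validator: raises eagerly, is a no-op when traced."""
    ok = jnp.all(jnp.asarray(scale) > 0)
    if isinstance(ok, jax.core.Tracer):
        return  # a tracer has no concrete truth value to branch on
    if not ok:
        raise ValueError("scale must be positive")


def log_scale(scale):
    check_positive(scale)
    return jnp.log(scale)
```

Called eagerly, `log_scale(-1.0)` raises `ValueError`. Under `jax.jit`, both `jax.jit(log_scale)(-1.0)` and `jax.jit(lambda: log_scale(-1.0))()` skip the check and return NaN, because the coerced value is compared inside the trace in either case.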

Performance

Coercion is a jnp.asarray per parameter, but it is skipped for arguments that are already jax.Array (a cheap isinstance check instead of a dispatch), so the cost falls almost entirely on constructing from python scalars / numpy. Eager-construction microbenchmarks (min of repeated runs):

| construction | master | branch |
| --- | --- | --- |
| Normal(0., 1.) (python scalars) | ~5.4 µs | ~31 µs |
| Normal(jnp.asarray(0.), jnp.asarray(1.)) | ~1.8 µs | ~2.0 µs |
| Normal(jnp.zeros(100), 1.) | ~9.7 µs | ~10.1 µs |
| MultivariateNormal(jnp.zeros(5), jnp.eye(5)) | ~30.7 µs | ~30.6 µs |

So constructing from jax arrays (the typical case in real code, and what happens under tracing) is unaffected. Only constructing from python-scalar literals pays the full conversion (Normal(0., 1.) ≈6×), and even then it is tens of microseconds. The jitted MCMC/SVI path is unchanged: the model and the distributions it builds are traced, so the coercion folds away (a jitted log-density constructing Normal + Beta is ~3.1 µs/call on both branches, with equal compile time).

Explicitly not in this PR

Enabling mypy enforcement for the distribution modules: the output annotations are now correct, but fully type-clean modules also require resolving pre-existing [override] signature mismatches and Optional handling, which are out of scope here.

Widening constructor inputs to ArrayLike — see Next steps.

Next steps

This PR fixes the output side of the convention (strict Array out). The input side is a natural follow-up: a number of constructor parameters are still annotated with the strict Array (e.g. CategoricalProbs.probs, Dirichlet.concentration, MatrixNormal.loc/scale_tril_*, LowRankMultivariateNormal.*, Distribution.mask), so passing a numpy array or a python scalar to them is a type error even though it works at runtime. These should be widened to ArrayLike (be liberal in what you accept).

Widening the annotations is the easy part; the work is that it surfaces constructor bodies that access .ndim/.shape/.reshape/etc. on a parameter before it is coerced to an array (those attributes are not defined on the scalar members of ArrayLike). So the follow-up is "widen the input annotations and coerce the parameters at the top of each affected constructor", done together. It is kept separate here to keep this PR focused on outputs and to give the input change its own review/CI cycle.

(CAR.adj_matrix is a deliberate exception — it also accepts a scipy.sparse matrix, so it needs a union rather than plain ArrayLike.)
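The "widen and coerce at the top" pattern for the follow-up might look like the sketch below. `ToyDirichlet` is a hypothetical stand-in, not the real `Dirichlet` constructor.

```python
import jax.numpy as jnp
import numpy as np
from jax import Array
from jax.typing import ArrayLike


class ToyDirichlet:
    """Hypothetical stand-in for a constructor that inspects its parameter."""

    def __init__(self, concentration: ArrayLike) -> None:
        # Coerce first: .ndim / .shape are not defined on python scalars.
        concentration = jnp.asarray(concentration)
        if concentration.ndim < 1:
            raise ValueError("`concentration` must be at least one-dimensional.")
        self.concentration: Array = concentration
        self.event_shape = concentration.shape[-1:]


d = ToyDirichlet(np.ones(3))
```

Because the coercion happens before any attribute access, the widened `ArrayLike` annotation and the body agree: a numpy array is accepted, and a python scalar reaches the explicit `ValueError` instead of failing with an `AttributeError`.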

Testing

test/test_distributions.py passes in full, including the new test_outputs_are_arrays and test_aux_fields_are_not_jax_arrays. The eager-only argument/sample validation behavior is covered by the existing validation tests, updated to reflect the new semantics.

@tillahoffmann tillahoffmann marked this pull request as draft June 8, 2026 21:57
tillahoffmann and others added 5 commits June 8, 2026 23:48
Coerce stored parameters to jax.Array in Distribution.__init__ so that
moments, samples, and log_prob consistently return jax arrays rather than
python scalars or numpy arrays. Real-valued (non-discrete) integer
parameters are promoted to float so Poisson(2) behaves identically to
Poisson(2.0) (pyro-pplgh-2181); the now-redundant per-method rate cast in
Poisson.log_prob is dropped.

Normalization runs *after* argument validation, which relies on constant
parameters folding under jit. It walks the MRO so reparametrizing subclasses
(e.g. Chi2, NegativeBinomialProbs) normalize inherited parameters, and skips
pytree aux fields (e.g. total_count, adj_matrix) and parameters that are not
array-coercible (e.g. a scipy-sparse adj_matrix).

Because parameters are now (traced) jax arrays, parameter-dependent
constraint bounds can be jax arrays too. Fix Constraint dispatch to select
jax.numpy whenever any operand -- the checked value or a bound -- is a jax
array, so constraints no longer raise under jit while keeping the fast eager
numpy path. Argument/sample validation is consequently an eager-only
operation (it cannot branch on traced masks under jit); the corresponding
omnistaging test assertions are updated to reflect this.

Add test_outputs_are_arrays asserting that samples, moments, and stored
parameters are jax arrays and that real-valued parameters are floating point.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

Flip return-position annotations from ArrayLike to Array across the
distribution modules (mean, variance, entropy, sample, log_prob, cdf, icdf,
enumerate_support, ...). Input parameters and method arguments keep ArrayLike
(be liberal in what we accept); only outputs change, since these methods
always return concrete jax arrays -- guaranteed by the parameter normalization
in Distribution.__init__.

ArrayLike is a wide input union (Array | ndarray | bool | int | float |
complex | ...); annotating it on outputs forced consumers to defend against
its scalar members, e.g. `dist.Normal(...).mean[i]`, `.variance.shape`, or
passing a moment into a function typed `(x: Array)` were all type errors.

Also add the Array import to conjugate/truncated/censored/flows and update
stale `:rtype: ArrayLike` docstrings to `:rtype: Array`.

`_validate_sample` deliberately keeps an ArrayLike return: its mask must not
be coerced to a jax array, since a concrete numpy mask coerced under jit
becomes a tracer and would defeat the `not_jax_tracer` check that gates the
out-of-support warning.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

…type

Use functional jax ops on `value` arguments where only shape or a reduction is
needed (`jnp.shape(value)` instead of `value.shape`, `jnp.sum(value, -1)`
instead of `value.sum(-1)`), which work for any ArrayLike rather than assuming
an array. Coerce with `jnp.asarray(value)` only where the body genuinely
indexes the value (GaussianRandomWalk, SineBivariateVonMises,
TruncatedPolyaGamma).

Also fix IntervalCensoredDistribution._validate_sample, which was annotated
`-> None` despite returning a mask (and despite the base contract returning
ArrayLike that the validate_sample wrapper uses for masking); annotate it
`-> ArrayLike` to match.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

…ger-only

Replace the construction-time `_normalize_args` pass with explicit coercion at
parameter-assignment sites, so stored parameters are statically typed as
`jax.Array` (a `setattr` loop is invisible to the type checker). `promote_shapes`
gains an opt-in `promote_array=True` (typed to return `list[Array]`) used at the
assignment sites; the few directly-assigned parameters are wrapped in
`jnp.asarray`. `promote_shapes` is otherwise unchanged and type-preserving, so
existing callers are unaffected.

Pytree *aux* fields are deliberately left un-coerced: `total_count` on the
multinomial family must stay concrete so that constructing a Multinomial inside
`jit` keeps it a Python int (the sampler needs `int(np.max(total_count))`), and a
scipy-sparse `adj_matrix` on `CAR` must not be arrayified. Two tests guard this:
`test_multinomial_total_count_static_under_jit` (the concrete failure mode) and
`test_aux_fields_are_not_jax_arrays` (the general invariant across all
distributions).

Because parameters are now arrays before `validate_args` runs, argument
validation is an eager-only operation: an out-of-bounds *constant* parameter
inside `jit` no longer raises (it previously did by constant-folding). This was
a narrow, hard-to-reason-about guarantee -- it fired only for literal constants,
never for traced parameters, and is inconsistent with sample validation, which
is already gated on concreteness.

Annotate `DistributionMeta.__call__` as returning `Distribution` so that the
`-> Array` output annotations are reachable through the metaclass (instances are
no longer typed `Any`).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

- Fix SineBivariateVonMises.__init__: `assert correlation` evaluated the truth
  value of a (possibly multi-element) array and raised; use
  `assert correlation is not None`. This was caught by the new
  test_aux_fields_are_not_jax_arrays, added after the previous full test run and
  so slipped into the prior commit untested.
- Optimize `promote_shapes(..., promote_array=True)` to skip `jnp.asarray` for
  arguments that are already jax arrays (a cheap isinstance check), so coercion
  is near-free when parameters are already arrays -- the typical case in traced
  / performance-sensitive code.
- Type DistributionMeta.__call__ to return the concrete class being constructed,
  so e.g. `Normal(...)` is typed `Normal` (not `Any` or the base `Distribution`)
  and the `-> Array` method annotations are reachable by type checkers. mypy
  rejects the (correct) metaclass self-type and super() call as unsound (known
  gaps: mypy#3625, mypy#11678); ty/pyright accept them, so the ignores are mypy-
  only.
- Rename test_outputs_are_arrays -> test_params_and_outputs_are_arrays (it also
  checks stored parameters), and drop the redundant
  test_multinomial_total_count_static_under_jit (subsumed by the general
  test_aux_fields_are_not_jax_arrays).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@tillahoffmann tillahoffmann force-pushed the add-jax-array-output-test branch from ee74d90 to 2e2dcc1 Compare June 9, 2026 18:51
tillahoffmann and others added 2 commits June 9, 2026 16:31
…leaks

The parameter coercion in this branch turned several discrete-distribution
parameters into jax arrays. For parameters that `enumerate_support` reads
concretely (it builds the support range with `jnp.arange`), that makes them
tracers under enumeration / multi-chain `vmap`, raising `ConcretizationTypeError`
(surfaced by `test_discrete_gibbs_multiple_sites_chain` and the `annotation.py
--model mace` example). Keep those parameters concrete:

- `BinomialProbs`/`BinomialLogits`: do not coerce `total_count` (coerce `probs`/
  `logits` only).
- `BetaBinomial`: do not coerce `total_count` (it reuses
  `BinomialProbs.enumerate_support`); coerce the concentrations only.
- `DiscreteUniform`: do not coerce `low`/`high`; coerce the moment outputs
  (`mean`/`variance`) instead so they remain jax arrays.

`test_enumerate_support_under_vmap` guards this class of regression directly by
enumerating under `vmap`.

`test_outputs_are_arrays` now checks method outputs only (samples, log_prob,
mean, variance, entropy), not stored parameters: the array-ness we promise is on
outputs, and some parameters are intentionally kept concrete.

`test_distributions_mixture.py` builds its component distributions lazily so no
jax arrays are created at import time (which would trip the `jax.live_arrays()`
guard in conftest, since parameters are now coerced at construction).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

`test_gof.py` constructs distributions from the shared `CONTINUOUS`/`DIRECTIONAL`
parameter lists, which now wrap some base distributions (EulerMaruyama's and
TwoSidedTruncatedDistribution's) in `LazyDist` to avoid creating jax arrays at
import time. Resolve those to concrete distributions via `_resolve_params`
before construction, as the other tests do; otherwise the raw `LazyDist` reaches
the constructor and fails its `isinstance(..., Distribution)` check.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@tillahoffmann tillahoffmann marked this pull request as ready for review June 9, 2026 22:52
@juanitorduz

Collaborator

Hi @tillahoffmann! Could we try adding these modified modules in https://github.com/pyro-ppl/numpyro/blob/master/pyproject.toml#L100 so that mypy runs the type checks? If we get many errors, we can tackle them in smaller PRs 🤞
