Ensure distribution moments/samples are jax arrays#2206
Open
tillahoffmann wants to merge 7 commits into
Coerce stored parameters to jax.Array in Distribution.__init__ so that moments, samples, and log_prob consistently return jax arrays rather than python scalars or numpy arrays. Real-valued (non-discrete) integer parameters are promoted to float so Poisson(2) behaves identically to Poisson(2.0) (pyro-pplgh-2181); the now-redundant per-method rate cast in Poisson.log_prob is dropped. Normalization runs *after* argument validation, which relies on constant parameters folding under jit. It walks the MRO so reparametrizing subclasses (e.g. Chi2, NegativeBinomialProbs) normalize inherited parameters, and skips pytree aux fields (e.g. total_count, adj_matrix) and parameters that are not array-coercible (e.g. a scipy-sparse adj_matrix). Because parameters are now (traced) jax arrays, parameter-dependent constraint bounds can be jax arrays too. Fix Constraint dispatch to select jax.numpy whenever any operand -- the checked value or a bound -- is a jax array, so constraints no longer raise under jit while keeping the fast eager numpy path. Argument/sample validation is consequently an eager-only operation (it cannot branch on traced masks under jit); the corresponding omnistaging test assertions are updated to reflect this. Add test_outputs_are_arrays asserting that samples, moments, and stored parameters are jax arrays and that real-valued parameters are floating point. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
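The integer-to-float promotion described in this commit can be pictured with a small sketch. `coerce_real_parameter` is a hypothetical helper written for illustration only; it is not numpyro's implementation, and it assumes nothing beyond `jax` being installed.

```python
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike


def coerce_real_parameter(x: ArrayLike) -> jax.Array:
    """Hypothetical helper (not numpyro API): coerce a real-valued parameter."""
    arr = jnp.asarray(x)
    # Integer inputs are promoted to the default float dtype, so that an
    # integer literal and the equivalent float literal store the same thing.
    if jnp.issubdtype(arr.dtype, jnp.integer):
        arr = arr.astype(jnp.result_type(float))
    return arr


rate_from_int = coerce_real_parameter(2)
rate_from_float = coerce_real_parameter(2.0)
```

Under this sketch both values are floating-point jax arrays with the same dtype, which is the property the commit describes for `Poisson(2)` and `Poisson(2.0)`.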
Flip return-position annotations from ArrayLike to Array across the distribution modules (mean, variance, entropy, sample, log_prob, cdf, icdf, enumerate_support, ...). Input parameters and method arguments keep ArrayLike (be liberal in what we accept); only outputs change, since these methods always return concrete jax arrays -- guaranteed by the parameter normalization in Distribution.__init__. ArrayLike is a wide input union (Array | ndarray | bool | int | float | complex | ...); annotating it on outputs forced consumers to defend against its scalar members, e.g. `dist.Normal(...).mean[i]`, `.variance.shape`, or passing a moment into a function typed `(x: Array)` were all type errors. Also add the Array import to conjugate/truncated/censored/flows and update stale `:rtype: ArrayLike` docstrings to `:rtype: Array`. `_validate_sample` deliberately keeps an ArrayLike return: its mask must not be coerced to a jax array, since a concrete numpy mask coerced under jit becomes a tracer and would defeat the `not_jax_tracer` check that gates the out-of-support warning. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
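The difference between a wide and a strict return annotation can be sketched without numpyro. The functions `mean_wide`, `mean_strict`, and `first_element` below are hypothetical stand-ins for a distribution moment and a typed consumer; only `jax.Array` and `jax.typing.ArrayLike` are real names.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.typing import ArrayLike


def mean_wide(loc: ArrayLike) -> ArrayLike:
    # Wide return type: hands back whatever the caller passed in.
    return loc


def mean_strict(loc: ArrayLike) -> jax.Array:
    # Strict return type: always a jax array, so indexing and .shape are safe.
    return jnp.asarray(loc)


def first_element(x: jax.Array) -> jax.Array:
    # A typed consumer that requires an actual array.
    return x[0]


# A type checker accepts this call because mean_strict is annotated -> Array.
head = first_element(mean_strict(np.array([1.0, 2.0])))
# mean_wide(1.0) is a plain Python float at runtime and has no .shape.
scalar_has_shape = hasattr(mean_wide(1.0), "shape")
strict_shape = mean_strict(1.0).shape
```

Passing `mean_wide(...)` to `first_element` is the kind of call a checker would reject, since the scalar members of `ArrayLike` are not `Array`.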
…type Use functional jax ops on `value` arguments where only shape or a reduction is needed (`jnp.shape(value)` instead of `value.shape`, `jnp.sum(value, -1)` instead of `value.sum(-1)`), which work for any ArrayLike rather than assuming an array. Coerce with `jnp.asarray(value)` only where the body genuinely indexes the value (GaussianRandomWalk, SineBivariateVonMises, TruncatedPolyaGamma). Also fix IntervalCensoredDistribution._validate_sample, which was annotated `-> None` despite returning a mask (and despite the base contract returning ArrayLike that the validate_sample wrapper uses for masking); annotate it `-> ArrayLike` to match. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
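As a rough illustration of why the functional forms are safer for `ArrayLike` inputs, the sketch below uses two hypothetical helpers, `batch_shape_of` and `total_over_last_axis`; they are not taken from the numpyro code base.

```python
import jax.numpy as jnp
import numpy as np
from jax import Array
from jax.typing import ArrayLike


def batch_shape_of(value: ArrayLike) -> tuple[int, ...]:
    # jnp.shape works for Python scalars and numpy arrays alike,
    # whereas `value.shape` fails on a plain float.
    return jnp.shape(value)[:-1]


def total_over_last_axis(value: ArrayLike) -> Array:
    # jnp.sum accepts any ArrayLike and always returns a jax array;
    # `value.sum(-1)` assumes the input already has array methods.
    return jnp.sum(value, -1)


scalar_shape = jnp.shape(2.5)
counts = np.array([[1, 2, 3], [4, 5, 6]])
totals = total_over_last_axis(counts)
```

Here `scalar_shape` is the empty tuple and `totals` is a jax array even though `counts` is a numpy array.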
…ger-only Replace the construction-time `_normalize_args` pass with explicit coercion at parameter-assignment sites, so stored parameters are statically typed as `jax.Array` (a `setattr` loop is invisible to the type checker). `promote_shapes` gains an opt-in `promote_array=True` (typed to return `list[Array]`) used at the assignment sites; the few directly-assigned parameters are wrapped in `jnp.asarray`. `promote_shapes` is otherwise unchanged and type-preserving, so existing callers are unaffected. Pytree *aux* fields are deliberately left un-coerced: `total_count` on the multinomial family must stay concrete so that constructing a Multinomial inside `jit` keeps it a Python int (the sampler needs `int(np.max(total_count))`), and a scipy-sparse `adj_matrix` on `CAR` must not be arrayified. Two tests guard this: `test_multinomial_total_count_static_under_jit` (the concrete failure mode) and `test_aux_fields_are_not_jax_arrays` (the general invariant across all distributions). Because parameters are now arrays before `validate_args` runs, argument validation is an eager-only operation: an out-of-bounds *constant* parameter inside `jit` no longer raises (it previously did by constant-folding). This was a narrow, hard-to-reason-about guarantee -- it fired only for literal constants, never for traced parameters, and is inconsistent with sample validation, which is already gated on concreteness. Annotate `DistributionMeta.__call__` as returning `Distribution` so that the `-> Array` output annotations are reachable through the metaclass (instances are no longer typed `Any`). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
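The "gated on concreteness" behavior mentioned in this commit can be sketched as follows. `check_positive` and `log_scale` are hypothetical names used for illustration; this is not numpyro's validation code, only a minimal picture of a check that runs eagerly and is skipped for traced values.

```python
import jax
import jax.numpy as jnp


def check_positive(scale):
    """Hypothetical validator: only inspects concrete (non-traced) values."""
    if isinstance(scale, jax.core.Tracer):
        return  # a tracer has no concrete value to branch on, so skip
    if not bool(jnp.all(jnp.asarray(scale) > 0)):
        raise ValueError("scale must be positive")


def log_scale(scale):
    check_positive(scale)
    return jnp.log(scale)


def raises_eagerly() -> bool:
    # Outside jit the value is concrete, so the check fires.
    try:
        log_scale(-1.0)
    except ValueError:
        return True
    return False


# Under jit the argument is traced: no error is raised and the result is NaN.
jitted_result = jax.jit(log_scale)(-1.0)
```

Running the function once without `jit` is therefore the way to surface the error in this sketch.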
- Fix SineBivariateVonMises.__init__: `assert correlation` evaluated the truth value of a (possibly multi-element) array and raised; use `assert correlation is not None`. This was caught by the new test_aux_fields_are_not_jax_arrays, added after the previous full test run and so slipped into the prior commit untested.
- Optimize `promote_shapes(..., promote_array=True)` to skip `jnp.asarray` for arguments that are already jax arrays (a cheap isinstance check), so coercion is near-free when parameters are already arrays -- the typical case in traced / performance-sensitive code.
- Type DistributionMeta.__call__ to return the concrete class being constructed, so e.g. `Normal(...)` is typed `Normal` (not `Any` or the base `Distribution`) and the `-> Array` method annotations are reachable by type checkers. mypy rejects the (correct) metaclass self-type and super() call as unsound (known gaps: mypy#3625, mypy#11678); ty/pyright accept them, so the ignores are mypy-only.
- Rename test_outputs_are_arrays -> test_params_and_outputs_are_arrays (it also checks stored parameters), and drop the redundant test_multinomial_total_count_static_under_jit (subsumed by the general test_aux_fields_are_not_jax_arrays).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
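The first fix in this commit concerns a general pitfall with array truthiness. A minimal sketch, using a made-up `correlation` value rather than the actual `SineBivariateVonMises` constructor:

```python
import jax.numpy as jnp

# Illustrative value only; any array with more than one element behaves this way.
correlation = jnp.array([0.1, 0.2])


def truthiness_check_fails() -> bool:
    try:
        # Asks for bool(array), which is ambiguous for more than one element.
        assert correlation
    except ValueError:
        return True
    return False


def presence_check_passes() -> bool:
    # Only tests whether a value was supplied, never its contents.
    assert correlation is not None
    return True
```

The `is not None` form is the one that matches the intent "a correlation was given", independent of the array's shape.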
ee74d90 to
2e2dcc1
…leaks

The parameter coercion in this branch turned several discrete-distribution parameters into jax arrays. For parameters that `enumerate_support` reads concretely (it builds the support range with `jnp.arange`), that makes them tracers under enumeration / multi-chain `vmap`, raising `ConcretizationTypeError` (surfaced by `test_discrete_gibbs_multiple_sites_chain` and the `annotation.py --model mace` example). Keep those parameters concrete:

- `BinomialProbs`/`BinomialLogits`: do not coerce `total_count` (coerce `probs`/`logits` only).
- `BetaBinomial`: do not coerce `total_count` (it reuses `BinomialProbs.enumerate_support`); coerce the concentrations only.
- `DiscreteUniform`: do not coerce `low`/`high`; coerce the moment outputs (`mean`/`variance`) instead so they remain jax arrays.

`test_enumerate_support_under_vmap` guards this class of regression directly by enumerating under `vmap`.

`test_outputs_are_arrays` now checks method outputs only (samples, log_prob, mean, variance, entropy), not stored parameters: the array-ness we promise is on outputs, and some parameters are intentionally kept concrete.

`test_distributions_mixture.py` builds its component distributions lazily so no jax arrays are created at import time (which would trip the `jax.live_arrays()` guard in conftest, since parameters are now coerced at construction).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
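The failure mode described in this commit can be reproduced in isolation. The two functions below are simplified, hypothetical stand-ins (not the numpyro methods of the same name); they only show that `jnp.arange` needs a concrete bound, and that keeping the count a Python int avoids the problem under `vmap`.

```python
import jax
import jax.numpy as jnp


def enumerate_support(total_count):
    # jnp.arange needs a concrete stop value, because it fixes the output shape.
    return jnp.arange(total_count + 1)


def weighted_support(probs, total_count=3):
    # total_count stays a Python int, so only probs is traced under vmap.
    return enumerate_support(total_count) * probs


# Works: the support has a static length of 4 for each of the 2 batch entries.
batched = jax.vmap(weighted_support)(jnp.array([0.2, 0.5]))


def traced_count_fails() -> bool:
    try:
        # Here total_count itself is mapped over, so it reaches arange as a tracer.
        jax.vmap(enumerate_support)(jnp.array([3, 3]))
    except TypeError:  # jax's concretization errors subclass TypeError
        return True
    return False
```

This is why the commit leaves `total_count`, `low`, and `high` un-coerced while still coercing the remaining parameters.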
`test_gof.py` constructs distributions from the shared `CONTINUOUS`/`DIRECTIONAL` parameter lists, which now wrap some base distributions (EulerMaruyama's and TwoSidedTruncatedDistribution's) in `LazyDist` to avoid creating jax arrays at import time. Resolve those to concrete distributions via `_resolve_params` before construction, as the other tests do; otherwise the raw `LazyDist` reaches the constructor and fails its `isinstance(..., Distribution)` check. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Collaborator
Hi @tillahoffmann! Could we try adding these modified modules in https://github.com/pyro-ppl/numpyro/blob/master/pyproject.toml#L100 so that mypy runs the type checks? If we get many errors, we can tackle them in smaller PRs 🤞
Summary

This PR makes numpyro distributions return concrete `jax.Array` from their public methods (`mean`, `variance`, `entropy`, `sample`, `log_prob`, `cdf`, `icdf`, …) and annotates them accordingly, so that typed downstream code can actually use distribution outputs. To make that true at runtime, parameters are coerced to `jax.Array` at construction.

The change aligns numpyro with JAX's own typing best practices (wide on input, `ArrayLike`; strict on output, `Array`) and removes a class of type errors that typed consumers will hit once they rely on numpyro's types.

The near-term, standalone win is the parameter coercion: today, whether a stored parameter (`self.loc`, `self.rate`, …) is a `jax.Array`, a numpy array, or a python scalar depends on the input; this PR makes it consistently a `jax.Array`. The output annotations are the cheap, correct-by-construction follow-through; they become observable to consumers once numpyro ships `py.typed` (a deliberate follow-up, see below).

This is necessarily runtime work, not just an annotation change: many moments return a stored parameter directly (`Poisson.mean` is `self.rate`), so an `Array` return type is only honest if that parameter is actually an array. Typing the outputs without making the runtime match would just lie to the type checker.

It also changes one behavior deliberately: argument validation (`validate_args=True`) is now an eager-only operation. That is the part most worth discussing, so it is called out explicitly below.

Motivation: the output annotations are wrong, and it bites consumers
Public methods were annotated to return `jax.typing.ArrayLike`.

`ArrayLike` is a deliberately wide input union. JAX's typing guidance is explicit that it is the right type for function inputs (be liberal in what you accept) and that outputs should be annotated `Array` (be strict in what you return). numpyro applied `ArrayLike` to outputs too, which is incorrect: `mean`/`sample`/`log_prob`/… (should) always produce a concrete array.

The cost lands on every typed consumer. The scalar members of the union (`bool`/`int`/`float`/`complex`) support neither indexing, `.shape`, `.reshape`, nor assignment to an `Array` parameter, so ordinary usage such as `dist.Normal(...).mean[i]`, `.variance.shape`, or passing a moment into a function typed `(x: Array)` fails type checking.

Under pyright, every one of those uses is a type error when the method returns `ArrayLike` (each indexing/attribute access fails once per scalar member of the union: `int`, `float`, `complex`, …), and all of them type-check cleanly once it returns `Array`.

This is currently latent (numpyro ships no `py.typed`, so checkers treat it as untyped `Any`), but it is a trap: the moment a consumer relies on numpyro's types, or numpyro starts enforcing its own, these annotations break real code.

The benefit here is for downstream typed code. This PR does not aim to make numpyro's own modules type-clean: they are not mypy-enforced, and their residual errors are dominated by pre-existing `[override]` signature mismatches unrelated to these annotations (enabling enforcement is called out as out of scope below).

What this PR does
- Coerce parameters to `jax.Array` at construction. Previously parameter storage was inconsistent: most parameters flow through `promote_shapes`, which preserved input types and only produced arrays when a broadcast forced a reshape, while some distributions stored the raw argument directly, so whether `self.loc` ended up an array depended on the input shapes and dtype. Now `promote_shapes` takes an opt-in `promote_array=True` (typed to return `list[Array]`) used at the parameter-assignment sites, and the few directly-assigned parameters are wrapped in `jnp.asarray`. As a result `self.loc`, `self.scale`, … are statically and dynamically `jax.Array`.

  This is runtime work, not just annotation, on purpose: annotating the outputs as `Array` without actually returning arrays would be a lie to the type checker. Many moments simply return a stored parameter (`Poisson.mean` is `self.rate`), so the annotation is only honest if the parameter is genuinely an array. Typing that isn't backed by the runtime is worse than no typing.

- Annotate outputs as `Array`. Return-position `ArrayLike` becomes `Array` across the distribution modules. Constructor parameters and method arguments (`value`, `q`) keep `ArrayLike`, so callers can still pass python scalars / numpy arrays.

- Make instance types reachable through the metaclass. `DistributionMeta.__call__` is annotated to return the concrete class being constructed, so `dist.Normal(...)` is typed `Normal` (not `Any`, and not the base `Distribution`) and the `-> Array` annotations on its methods are visible to consumers.

- Use functional array ops on inputs where appropriate. Where a method only needed a shape it now uses `jnp.shape(value)` rather than `value.shape`, which is correct for any `ArrayLike` (and incidentally more robust). A few methods that genuinely index `value` coerce it once.

- Parameters that must remain non-array are preserved: pytree aux fields (`total_count` on the multinomial family, `adj_matrix` on `CAR`) are static metadata, and coercing them to a `jax.Array` would make them tracers under `jit`/`vmap` and break sampling, so they are left un-coerced, as is a scipy-sparse `adj_matrix` (`CAR(is_sparse=True)`).

The coercion touches a parameter-assignment site in essentially every distribution, so correctness rests on tests that run across all distributions rather than on per-site review:

- `test_params_and_outputs_are_arrays`: samples, moments, and stored (non-aux) parameters are `jax.Array`. This catches any data parameter the sweep missed.
- `test_aux_fields_are_not_jax_arrays`: no pytree aux field is ever a `jax.Array`. This catches the opposite mistake: an aux field (e.g. `total_count`) wrongly coerced, which would break `jit`/`vmap` (for instance, a jitted `Multinomial(...).sample(...)` needs a concrete `total_count`).

The deliberate behavior change: validation is eager-only
Because parameters are coerced to arrays before `Distribution.__init__` runs `validate_args`, an out-of-bounds constant parameter inside `jit` no longer raises. Previously (#775) it did, by keeping parameters as concrete (non-jax) values so the validity check could constant-fold. #775 added this so that a typo'd literal, such as `Normal(0., -1.)` in a jitted model, would fail loudly at trace time rather than silently produce NaNs. That is a real concern, so this is worth making the case for rather than glossing over.

The guarantee it provided was always narrow. It fired only for parameters that were compile-time constants; the moment the same bad value arrives through a traced argument (`Normal(0., scale)` with `scale = -1.`), validation cannot run (you can't branch on a tracer) and never has. So on `master` today, with the same `validate_args=True` and the same invalid value, whether you get an error depends on how the value was staged: a literal constant raises at trace time, whereas the same value passed as a traced argument does not.

A safety net that catches `-1.0` written as a literal but not `-1.0` passed as an argument is not one a user can rely on without knowing how their parameters get staged.

The new behavior is simpler to state: `validate_args=True` validates eagerly and does nothing under `jit`. Eager validation is unchanged: invalid parameters raise exactly as before outside `jit`, which is where it is most useful for debugging (construct the distribution eagerly, or run once without `jit`, to surface the error).

This also lines up with how out-of-support sample validation already behaves. That check warns only when the support mask is concrete, so it too goes silent under `jit` whenever the sampled value is traced, which is the usual case, since in a model/MCMC/SVI the value fed to `log_prob` is the (traced) data. The only situation where it still fires under `jit` is a constant-support distribution with a constant value, the same accidental-constant-folding corner that parameter validation used to rely on. So under realistic jitted usage both parameter and sample validation are effectively eager-only.

Performance
Coercion is a `jnp.asarray` per parameter, but it is skipped for arguments that are already `jax.Array` (a cheap `isinstance` check instead of a dispatch), so the cost falls almost entirely on constructing from python scalars / numpy. Eager-construction microbenchmarks (min of repeated runs) covered these constructions:

- `Normal(0., 1.)` (python scalars)
- `Normal(jnp.asarray(0.), jnp.asarray(1.))`
- `Normal(jnp.zeros(100), 1.)`
- `MultivariateNormal(jnp.zeros(5), jnp.eye(5))`

So constructing from jax arrays (the typical case in real code, and what happens under tracing) is unaffected. Only constructing from python-scalar literals pays the full conversion (`Normal(0., 1.)` ≈5×), and even then it is tens of microseconds. The jitted MCMC/SVI path is unchanged: the model and the distributions it builds are traced, so the coercion folds away (a jitted log-density constructing `Normal` + `Beta` is ~3.1 µs/call on both branches, with equal compile time).

Explicitly not in this PR
- Enabling mypy enforcement for the distribution modules: the output annotations are now correct, but fully type-clean modules also require resolving pre-existing `[override]` signature mismatches and `Optional` handling, which are out of scope here.
- Widening constructor inputs to `ArrayLike` (see Next steps).

Next steps
This PR fixes the output side of the convention (strict `Array` out). The input side is a natural follow-up: a number of constructor parameters are still annotated with the strict `Array` (e.g. `CategoricalProbs.probs`, `Dirichlet.concentration`, `MatrixNormal.loc`/`scale_tril_*`, `LowRankMultivariateNormal.*`, `Distribution.mask`), so passing a numpy array or a python scalar to them is a type error even though it works at runtime. These should be widened to `ArrayLike` (be liberal in what you accept).

Widening the annotations is the easy part; the work is that it surfaces constructor bodies that access `.ndim`/`.shape`/`.reshape`/etc. on a parameter before it is coerced to an array (those attributes are not defined on the scalar members of `ArrayLike`). So the follow-up is "widen the input annotations and coerce the parameters at the top of each affected constructor", done together. It is kept separate here to keep this PR focused on outputs and to give the input change its own review/CI cycle.

(`CAR.adj_matrix` is a deliberate exception: it also accepts a `scipy.sparse` matrix, so it needs a union rather than plain `ArrayLike`.)

Testing
`test/test_distributions.py` passes in full, including the new `test_outputs_are_arrays` and `test_aux_fields_are_not_jax_arrays`. The eager-only argument/sample validation behavior is covered by the existing validation tests, updated to reflect the new semantics.