docs: add migration guides and tutorial#999
Conversation
ev-br
left a comment
There was a problem hiding this comment.
This is very nice.
I've a left several comments, mostly very minor. There are two themes:
- for the migration guide, I think it would be helpful to be some more specific and some more opinionated.
- In the tutorial, I love the power iteration example. The only concern is that you talk about JIT, but the example does not jit, not easily at least. So maybe either expand a bit on how to actually jit it, or state that not everything is benefits from jitting and add one more example which does?
spec/2025.12/migration_guide.md
Outdated
| `array-api-strict` is a library that provides a strict and minimal | ||
| implementation of the Array API Standard. It is designed to be used as | ||
| a reference implementation for testing and development purposes. By comparing | ||
| your API calls with `array-api-strict` counterparts, you can ensure that your |
There was a problem hiding this comment.
This suggestion reads as aimed at array API producers. Consider expanding to mention that an array API consumer can run their test suite with array-api-strict as a producer.
As a larger edit, consider briefly stressing in these short descriptions that array-api-compat is mainly for consumer's code, array-api-strict is a testing implement, also mostly for a consumer; array-api-tests is for a producer. While you are discussing this in more detail below, I think it'd be helpful to stress it in the very beginning, too.
There was a problem hiding this comment.
Done - I added a User group clause to each library and mentioned how producers/consumers use it.
| 3. Inside the `array-api-tests` directory run `pytest` command. There are | ||
| multiple useful options delivered by the test suite, a few worth mentioning: | ||
| - `--max-examples=2` - maximal number of test cases to generate by the | ||
| hypothesis. This allows you to balance between execution time of the test |
There was a problem hiding this comment.
Minor: add a link to https://hypothesis.readthedocs.io/en/latest/ + a couple of sentensies to explicitly state that hypothesis generates this many examples, where each example is a valid combination of inputs; and that the recommendation is to use as many examples as reasonable for the time budget.
spec/2025.12/migration_guide.md
Outdated
| - `-o xfail_strict=<bool>` is often used with the previous one. If a test | ||
| expected to fail actually passes (`XPASS`) then you can decide whether | ||
| to ignore that fact or raise it as an error. | ||
| - `--skips-file` for skipping files. At times some failing tests might stall |
There was a problem hiding this comment.
| - `--skips-file` for skipping files. At times some failing tests might stall | |
| - `--skips-file` for skipping tests. At times some failing tests might stall |
spec/2025.12/migration_guide.md
Outdated
|
|
||
| - If you are building a library where the backend is determined by input arrays | ||
| passed by the end-user, then a recommended way is to ask your input arrays for a | ||
| namespace to use: `xp = arr.__array_namespace__()` |
There was a problem hiding this comment.
For consumers, the support is less than ideal, consider mentioning array-api-compat and array_namespace function already?
In [9]: for xp in [np, torch, jnp, cupy, da, array_api_strict]:
...: try:
...: xp.ones(3).__array_namespace__()
...: print(f"{xp.__name__} check")
...: except:
...: print(f"{xp.__name__} nope")
...:
numpy check
torch nope
jax.numpy check
cupy nope
dask.array nope
array_api_strict check
There was a problem hiding this comment.
Done - I mentioned array_api_compat.array_namespace() as well there.
spec/2025.12/migration_guide.md
Outdated
| namespace to use: `xp = arr.__array_namespace__()` | ||
| - Each function you implement can have a namespace `xp` as a parameter in the | ||
| signature. Then enforcing inputs to be of type by the provided backend can be | ||
| achieved with `arg1 = xp.asarray(arg1)` for each input array. |
There was a problem hiding this comment.
If this is a first tutorial for a beginner, I think it could be helpful to be a bit more descriptive and a bit more opinionated. Roughly,
- for a function which accepts array arguments, use
def func(array1, scalar1, scalar2):
xp = array1.__array_namespace__() # or array_namespace(array1)
return xp.arange(scalar1, scalar2) @ array1
- for a function that accepts scalars and returns arrays, use
def func(s1, s2, xp):
return xp.arange(s1, s2)
If you prefer to not clutter the short guide with details, consider adding them to a separate page and linking from the short guide.
spec/2025.12/tutorial_basic.md
Outdated
| At this point the actual execution depends only on `xp` namespace, | ||
| and replacing that one variable allow us to switch from e.g. NumPy arrays | ||
| to a JAX execution on a GPU. This allows us to be more flexible, and, for | ||
| example use lazy evaluation and JIT compile a loop body with JAX's JIT compilation. |
There was a problem hiding this comment.
But jax jitting this iterative function with data-dependent control flow doesn't work! Simplifying a bit the hits function to accept an array-like, and jitting --- which is what I would do first thing after reading this tutorial--- gives
In [19]: def hits(A, max_iter=100, tol=1.0e-8, normalized=True):
...: A = xp.asarray(A)
...: N = A.shape[0]
...: h = xp.full(N, 1.0 / N)
...: # Power iteration: make up to max_iter iterations
...: for _i in range(max_iter):
...: hprev = h
...: a = hprev @ A
...: h = A @ a
...: h = h / xp.max(h)
...: if is_converged(hprev, h, N, tol):
...: break
...: else:
...: raise Exception("Didn't converge")
...: if normalized:
...: h = h / xp.sum(xp.abs(h))
...: a = a / xp.sum(xp.abs(a))
...: return h, a
...:
...: def is_converged(xprev, x, N, tol):
...: err = xp.sum(xp.abs(x - xprev))
...: return err < xp.asarray(N * tol)
...:
In [20]: import jax.numpy as xp
In [21]: jax.jit(hits)(jnp.eye(3))
---------------------------------------------------------------------------
TracerBoolConversionError Traceback (most recent call last)
Cell In[21], line 1
----> 1 jax.jit(hits)(jnp.eye(3))
[... skipping hidden 13 frame]
Cell In[19], line 11, in hits(A, max_iter, tol, normalized)
9 h = A @ a
10 h = h / xp.max(h)
---> 11 if is_converged(hprev, h, N, tol):
12 break
13 else:
[... skipping hidden 1 frame]
File ~/.conda/envs/scipy-dev/lib/python3.12/site-packages/jax/_src/core.py:1806, in concretization_function_error.<locals>.error(self, arg)
1805 def error(self, arg):
-> 1806 raise TracerBoolConversionError(arg)
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function hits at <ipython-input-19-d5a4e616b039>:1 for jit. This concrete value was not available in Python because it depends on the value of the argument A.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError
There was a problem hiding this comment.
Sorry for not being clear in the text. It says JIT compile a loop body with JAX's JIT compilation, so to JIT compile what is in the for _i in range(...). That is the hot part that we want to compile into a single kernel.
It's the same premise as in the JAX documentation:
To make it fully digestible by a beginner user I embedded a copy-paste code at the bottom with JAX jit compiled version. Is it sufficiently clear?
| ``` | ||
|
|
||
| And last but not least, let's ensure that the result of the convergence | ||
| condition is a scalar coming from our API: |
There was a problem hiding this comment.
I honestly don't understand this statement. err is an array, you are converting the r.h.s. to a 0D too, so is_converged function returns a 0D array, and if is_converged(...) relies on an automagic bool(0D_array) conversion. Which would also synchronize a device when run on a GPU.
There was a problem hiding this comment.
Here rhs and lhs are 0D arrays in the current form, and the result is a 0D so that it can be used as if condition.
We're outside of JIT compiled part (as explained in the previous comment) so we don't need to worry about synchronization.
This PR adds a migration guide (versions for array consumers and producers) and one migration tutorial showing how a simple power-iteration based algorithm from GraphBLAS can be moved to an Array API compatible version.