Skip to content

docs: add migration guides and tutorial#999

Open
mtsokol wants to merge 3 commits intodata-apis:mainfrom
mtsokol:tutorials-and-guides
Open

docs: add migration guides and tutorial#999
mtsokol wants to merge 3 commits intodata-apis:mainfrom
mtsokol:tutorials-and-guides

Conversation

@mtsokol
Copy link
Contributor

@mtsokol mtsokol commented Mar 13, 2026

This PR adds a migration guide (versions for array consumers and producers) and one migration tutorial showing how a simple power-iteration based algorithm from GraphBLAS can be moved to an Array API compatible version.

Copy link
Member

@ev-br ev-br left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very nice.

I've a left several comments, mostly very minor. There are two themes:

  • for the migration guide, I think it would be helpful to be some more specific and some more opinionated.
  • In the tutorial, I love the power iteration example. The only concern is that you talk about JIT, but the example does not jit, not easily at least. So maybe either expand a bit on how to actually jit it, or state that not everything is benefits from jitting and add one more example which does?

`array-api-strict` is a library that provides a strict and minimal
implementation of the Array API Standard. It is designed to be used as
a reference implementation for testing and development purposes. By comparing
your API calls with `array-api-strict` counterparts, you can ensure that your
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This suggestion reads as aimed at array API producers. Consider expanding to mention that an array API consumer can run their test suite with array-api-strict as a producer.

As a larger edit, consider briefly stressing in these short descriptions that array-api-compat is mainly for consumer's code, array-api-strict is a testing implement, also mostly for a consumer; array-api-tests is for a producer. While you are discussing this in more detail below, I think it'd be helpful to stress it in the very beginning, too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done - I added a User group clause to each library and mentioned how producers/consumers use it.

3. Inside the `array-api-tests` directory run `pytest` command. There are
multiple useful options delivered by the test suite, a few worth mentioning:
- `--max-examples=2` - maximal number of test cases to generate by the
hypothesis. This allows you to balance between execution time of the test
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: add a link to https://hypothesis.readthedocs.io/en/latest/ + a couple of sentensies to explicitly state that hypothesis generates this many examples, where each example is a valid combination of inputs; and that the recommendation is to use as many examples as reasonable for the time budget.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

- `-o xfail_strict=<bool>` is often used with the previous one. If a test
expected to fail actually passes (`XPASS`) then you can decide whether
to ignore that fact or raise it as an error.
- `--skips-file` for skipping files. At times some failing tests might stall
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
- `--skips-file` for skipping files. At times some failing tests might stall
- `--skips-file` for skipping tests. At times some failing tests might stall

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


- If you are building a library where the backend is determined by input arrays
passed by the end-user, then a recommended way is to ask your input arrays for a
namespace to use: `xp = arr.__array_namespace__()`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For consumers, the support is less than ideal, consider mentioning array-api-compat and array_namespace function already?

In [9]: for xp in [np, torch, jnp, cupy, da, array_api_strict]:
   ...:     try:
   ...:         xp.ones(3).__array_namespace__()
   ...:         print(f"{xp.__name__}  check")
   ...:     except:
   ...:         print(f"{xp.__name__}  nope")
   ...: 
numpy  check
torch  nope
jax.numpy  check
cupy  nope
dask.array  nope
array_api_strict  check

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done - I mentioned array_api_compat.array_namespace() as well there.

namespace to use: `xp = arr.__array_namespace__()`
- Each function you implement can have a namespace `xp` as a parameter in the
signature. Then enforcing inputs to be of type by the provided backend can be
achieved with `arg1 = xp.asarray(arg1)` for each input array.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is a first tutorial for a beginner, I think it could be helpful to be a bit more descriptive and a bit more opinionated. Roughly,

  • for a function which accepts array arguments, use
def func(array1, scalar1, scalar2):
    xp = array1.__array_namespace__()     # or array_namespace(array1)
    return xp.arange(scalar1, scalar2) @ array1
  • for a function that accepts scalars and returns arrays, use
def func(s1, s2, xp):
   return xp.arange(s1, s2)

If you prefer to not clutter the short guide with details, consider adding them to a separate page and linking from the short guide.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

At this point the actual execution depends only on `xp` namespace,
and replacing that one variable allow us to switch from e.g. NumPy arrays
to a JAX execution on a GPU. This allows us to be more flexible, and, for
example use lazy evaluation and JIT compile a loop body with JAX's JIT compilation.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But jax jitting this iterative function with data-dependent control flow doesn't work! Simplifying a bit the hits function to accept an array-like, and jitting --- which is what I would do first thing after reading this tutorial--- gives

In [19]: def hits(A, max_iter=100, tol=1.0e-8, normalized=True):
    ...:     A = xp.asarray(A)
    ...:     N = A.shape[0]
    ...:     h = xp.full(N, 1.0 / N)
    ...:     # Power iteration: make up to max_iter iterations
    ...:     for _i in range(max_iter):
    ...:         hprev = h
    ...:         a = hprev @ A
    ...:         h = A @ a
    ...:         h = h / xp.max(h)
    ...:         if is_converged(hprev, h, N, tol):
    ...:             break
    ...:     else:
    ...:         raise Exception("Didn't converge")
    ...:     if normalized:
    ...:         h = h / xp.sum(xp.abs(h))
    ...:         a = a / xp.sum(xp.abs(a))
    ...:     return h, a
    ...: 
    ...: def is_converged(xprev, x, N, tol):
    ...:     err = xp.sum(xp.abs(x - xprev))
    ...:     return err < xp.asarray(N * tol)
    ...: 

In [20]: import jax.numpy as xp

In [21]: jax.jit(hits)(jnp.eye(3))
---------------------------------------------------------------------------
TracerBoolConversionError                 Traceback (most recent call last)
Cell In[21], line 1
----> 1 jax.jit(hits)(jnp.eye(3))

    [... skipping hidden 13 frame]

Cell In[19], line 11, in hits(A, max_iter, tol, normalized)
      9     h = A @ a
     10     h = h / xp.max(h)
---> 11     if is_converged(hprev, h, N, tol):
     12         break
     13 else:

    [... skipping hidden 1 frame]

File ~/.conda/envs/scipy-dev/lib/python3.12/site-packages/jax/_src/core.py:1806, in concretization_function_error.<locals>.error(self, arg)
   1805 def error(self, arg):
-> 1806   raise TracerBoolConversionError(arg)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function hits at <ipython-input-19-d5a4e616b039>:1 for jit. This concrete value was not available in Python because it depends on the value of the argument A.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError

Copy link
Contributor Author

@mtsokol mtsokol Mar 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for not being clear in the text. It says JIT compile a loop body with JAX's JIT compilation, so to JIT compile what is in the for _i in range(...). That is the hot part that we want to compile into a single kernel.
It's the same premise as in the JAX documentation:

image

To make it fully digestible by a beginner user I embedded a copy-paste code at the bottom with JAX jit compiled version. Is it sufficiently clear?

```

And last but not least, let's ensure that the result of the convergence
condition is a scalar coming from our API:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I honestly don't understand this statement. err is an array, you are converting the r.h.s. to a 0D too, so is_converged function returns a 0D array, and if is_converged(...) relies on an automagic bool(0D_array) conversion. Which would also synchronize a device when run on a GPU.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here rhs and lhs are 0D arrays in the current form, and the result is a 0D so that it can be used as if condition.
We're outside of JIT compiled part (as explained in the previous comment) so we don't need to worry about synchronization.

@mtsokol mtsokol requested a review from ev-br March 17, 2026 16:49
@kgryte kgryte changed the title Add migration guides and tutorial docs: add migration guides and tutorial Mar 18, 2026
@kgryte kgryte added the Narrative Content Narrative documentation content. label Mar 18, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Narrative Content Narrative documentation content.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants