Replace jax.sharding.use_mesh with jax.set_mesh. jax.set_mesh can act as a global setter or a context manager.#65
Merged
Commits
Commits on Jun 11, 2026
- authored andcommitted

