Add guard at lowest JAX version that still supports triton kernel calling#2741
Add guard at lowest JAX version that still supports triton kernel calling#2741tdophung wants to merge 9 commits intoNVIDIA:mainfrom
Conversation
Signed-off-by: tdophung <tdophung@nvidia.com>
|
/te-ci jax |
for more information, see https://pre-commit.ci
jberchtold-nvidia
left a comment
There was a problem hiding this comment.
Awesome, LGTM pending CI, thanks!
Greptile SummaryThis PR introduces a minimum JAX version guard ( Key changes:
Confidence Score: 4/5
Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["Import triton_extensions or run Triton test"] --> B["is_triton_extension_supported()\n(from version_utils)"]
B --> C{JAX >= 0.8.0?}
C -- No --> D1["triton_extensions/utils.py\nraises RuntimeError"]
C -- No --> D2["conftest.py pytest_collection_modifyitems\nmarks @triton tests as skip"]
C -- No --> D3["require_triton_or_skip_test_file()\npytest.skip(allow_module_level=True)"]
C -- Yes --> E["Normal execution\n_check_triton_compatibility()"]
E --> F{Triton installed?}
F -- No --> G["ImportError: install triton"]
F -- Yes --> H["Import gpu_triton, triton.compiler\nKernel dispatch works"]
D2 --> I["_inject_* autouse fixture\nreturns early (no inject)"]
H --> J["_inject_* autouse fixture\ninjects module-level names\n(token_dispatch, fused_topk…)"]
subgraph version_utils.py
B
end
|
…lper.py Signed-off-by: tdophung <tdophung@nvidia.com>
- Add version_utils.py with is_triton_extension_supported() checking JAX >= 0.8.0 (release version, not dev snapshot) and TRITON_EXTENSION_MIN_JAX_VERSION constant - Add pytest.mark.triton marker and conftest hook to skip marked tests on old JAX - Add require_triton() for module-level skipping in test files - Rewrite triton_extensions to use is_triton_extension_supported() instead of direct jaxlib dev-version comparison Signed-off-by: tdophung <tdophung@nvidia.com>
for more information, see https://pre-commit.ci
…d re-export, revert test.sh - require_triton(): add allow_module_level=True to pytest.skip() so module-level calls on old JAX produce a proper skip instead of a collection failure - Remove is_triton_extension_supported from triton_extensions/utils.py __all__: importing triton_extensions on JAX < 0.8.0 raises immediately, so re-exporting the check from there defeats its purpose; callers should import directly from transformer_engine.jax.version_utils - Revert qa/L0_jax_lint/test.sh TE_PATH to /opt/transformerengine (local dev path was accidentally committed; pass TE_PATH= at invocation time instead) Signed-off-by: tdophung <tdophung@nvidia.com>
…l__ and hardcoded version - Move is_triton_extension_supported() guard before the gpu_triton import block with a comment clarifying the segfault is at dispatch time, not import time - Remove _jax_version_meet_requirement from version_utils __all__ (private helper, not a public API; callers import it explicitly as needed) - Use TRITON_EXTENSION_MIN_JAX_VERSION constant in conftest marker description instead of hardcoded '0.8.0' Signed-off-by: tdophung <tdophung@nvidia.com>
|
/te-ci jax |
Signed-off-by: tdophung <tdophung@nvidia.com>
for more information, see https://pre-commit.ci
|
/te-ci jax |
Description
To provide backward compatibility with older jax versions, we need to have a safeguard in place for jax versions too old to work with triton kernel calling. Using Claude Code to automate bisecting through JAX toolbox nightly containers between Sep 1, 2025 and Oct 1, 2025 (*), I have found that the first passing version of the container starts on Sep 24th, 2025, corresponding to jax 0.8.0.dev20250924 hence the guard is put there.
(*) the date range is determined by having a data point that the officially released jax toolbox (nvcr.io/nvidia/jax:25.10-py3 fails while the nightly jax container on Oct 1st passed.
Fixes # (issue)
Type of change
Changes
Handles jax < 0.8.0.dev20250924 segfault error when calling triton kernels frfom JAX side
Checklist: