
[JAX] Fix get_seqlens_and_offsets() to accept vmapped seg ids and non vmapped seg offsets#2692

Open
KshitijLakhani wants to merge 13 commits into NVIDIA:main from KshitijLakhani:klakhani/fix/vmap-get-seg-ids-pos
Conversation

@KshitijLakhani
Collaborator

@KshitijLakhani KshitijLakhani commented Feb 19, 2026

Description

What is the bug ?

TE provides a convenience function from_segment_ids_and_pos() that lets users pass only segment ids; it returns a SequenceDescriptor containing the passed segment ids along with internally generated segment pos.

As mentioned in Issue #2685, if a user vmaps a function forward() that i) accepts q, k, v, and segment ids, ii) calls from_segment_ids_and_pos(), and iii) then calls DPA(), JAX sees the segment ids as vmapped inputs and adds an extra leading dimension (e.g. (1, 2, 128)), whereas the segment offsets are not given a leading dimension (e.g. (2, 128)). This shape mismatch between seg ids and seg pos triggers the assert in the FusedAttn primitive impl(), as described in issue #2685.

What is the root cause for the bug ?

During debugging, it can be seen that the shapes start to differ when the batcher is traced for the FusedAttn primitive.
segment_ids in the primitive: treated as a vmapped input and hence batched → (1, 2, 128).
segment_pos in the primitive: treated as derived within the function and hence not batched → (2, 128).

Fixes #2685

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

There are two possible approaches to solve this:

  1. Resolve the issue at the source, i.e. ensure that segment_pos has the same leading batching dims as segment_ids, adding any extra dims in the batcher so that impl() sees matching shapes. Pros: issue resolved in a "JAX" way and at the source. Cons: increased memory from expanding the seg pos dims.
  2. Resolve the issue when impl() is called, i.e. accommodate mismatched seg id and seg pos dims when generating the seqlens and offsets. Pros: no extra memory needed, as no dims are expanded. Cons: not "truly" solved (at the source).

The second approach is chosen here as it is more memory-efficient. After this PR merges, end users can wrap TE API calls in vmap without worrying about the batching inside TE.
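The second approach can be sketched in a framework-agnostic way: flatten any extra leading dims into a single flat batch, run the per-example seqlens computation (the PR uses an inner jax.vmap over _segment_ids_pos_to_seqlens_offsets), then restore the original leading shape. The helper and per-example function below are hypothetical illustrations using numpy in place of jax:

```python
import math

import numpy as np

def seqlens_over_extra_dims(per_example_fn, segment_ids, base_ndim=2):
    """Flatten any extra leading (vmap-added) dims of segment_ids, apply the
    per-example computation, then restore the original leading shape."""
    split = segment_ids.ndim - base_ndim
    extra_shape = segment_ids.shape[:split]
    # math.prod(()) == 1 (an int), so the unbatched case needs no special path
    flat = segment_ids.reshape(math.prod(extra_shape), *segment_ids.shape[split:])
    out = np.stack([per_example_fn(ids) for ids in flat])  # stands in for the inner jax.vmap
    return out.reshape(*extra_shape, *out.shape[1:])

# Hypothetical per-example seqlens: count tokens with a nonzero segment id per row
seqlens_fn = lambda ids: (ids != 0).sum(axis=-1)

batched_ids = np.ones((1, 2, 128), dtype=np.int32)  # vmapped: extra leading dim
print(seqlens_over_extra_dims(seqlens_fn, batched_ids).shape)     # (1, 2)
print(seqlens_over_extra_dims(seqlens_fn, batched_ids[0]).shape)  # (2,)
```

Both the batched (1, 2, 128) and unbatched (2, 128) inputs flow through the same code path, which is the property the PR relies on.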

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

… the TE constructed segment pos are not thereby causing mismatches in impl()

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani self-assigned this Feb 19, 2026
@KshitijLakhani
Collaborator Author

/te-ci jax L0 L1 L2

@KshitijLakhani KshitijLakhani marked this pull request as ready for review February 20, 2026 06:54
@greptile-apps
Contributor

greptile-apps bot commented Feb 20, 2026

Greptile Summary

This PR fixes a shape mismatch crash that occurred when a user wrapped a TE attention call in jax.vmap: JAX would batch segment_ids (adding an extra leading dim, e.g. (1, 2, 128)) while segment_pos, derived inside the function via from_segment_ids_and_pos(), remained un-batched (e.g. (2, 128)), triggering an assertion in the FusedAttn primitive impl.

The fix removes the strict shape-equality assertions on segment_ids/segment_pos and instead, when extra leading dims are detected on segment_ids, flattens them, calls _segment_ids_pos_to_seqlens_offsets via an inner jax.vmap (broadcasting segment_pos), and reshapes the results back. Comments are also added to the primitive batchers explaining the intentional pass-through.

Key observations from the review:

  • jnp.prod(extra_batch_shape_q) returns a float (1.0) when extra_batch_shape_q is an empty tuple () (the case where only kv, but not q, has extra dims). Passing this float to reshape raises a TypeError. Using math.prod keeps the result a Python int.
  • There is no explicit check that extra_batch_shape_q == extra_batch_shape_kv. When they differ numerically the jax.vmap call raises an opaque JAX error; when they're numerically equal but structurally different the reshape produces inconsistently shaped q/kv outputs. An explicit assertion before the vmap would give a clear, actionable error.
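The first observation can be demonstrated outside of JAX, since jnp.prod follows NumPy's empty-product semantics. A minimal numpy sketch of the failure and the math.prod fix:

```python
import math

import numpy as np

# NumPy semantics (which jnp.prod follows): the product of an empty tuple is a float
empty_prod = np.prod(())
assert empty_prod == 1.0 and empty_prod.dtype == np.float64

# Feeding that float into reshape fails, mirroring the reported TypeError
x = np.zeros((2, 128))
try:
    x.reshape(empty_prod, 2, 128)
except TypeError as exc:
    print("reshape rejected the float batch size:", exc)

# math.prod stays in the integer domain, so the reshape works
y = x.reshape(math.prod(()), 2, 128)
assert y.shape == (1, 2, 128)
```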

Confidence Score: 3/5

  • The common vmap use-case (symmetric batching of q and kv) works correctly, but two edge cases in the new reshape path can raise confusing errors or silently produce inconsistent shapes.
  • The primary bug fix is logically sound for the described use-case (symmetric vmap of q and kv segment_ids). However, jnp.prod on an empty shape tuple returns a float 1.0, which will cause a TypeError in reshape for the asymmetric case (only one of q/kv vmapped). Additionally, there is no validation that the extra batch shapes of q and kv match, which can lead to either an opaque JAX error or silently inconsistent output shapes. No new tests were added to cover the vmapped scenario.
  • transformer_engine/jax/attention.py — specifically the reshape path at lines 739-763 of get_seqlens_and_offsets().

Important Files Changed

Filename Overview
transformer_engine/jax/attention.py Core fix: get_seqlens_and_offsets() now handles vmapped segment_ids with non-vmapped segment_pos by flattening extra leading dims, running jax.vmap over _segment_ids_pos_to_seqlens_offsets, then reshaping back. Two issues: jnp.prod() on an empty shape tuple returns a float causing a TypeError in reshape for asymmetric vmap cases, and there is no validation that q/kv extra batch shapes match before calling vmap.
transformer_engine/jax/cpp_extensions/attention.py Minor documentation-only change: added comments to FusedAttnFwdPrimitive.batcher and FusedAttnBwdPrimitive.batcher explaining why the pass-through is safe when segment_ids and segment_pos have mismatched batch dims. No logic changes introduced.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["get_seqlens_and_offsets()"] --> B{"segment_ids empty"}
    B -- yes --> C["Return self.seqlens / self.seq_offsets"]
    B -- no --> D{"ids.ndim &lt; pos.ndim"}
    D -- yes --> E["Raise: ids must not have fewer dims than pos"]
    D -- no --> F{"trailing shapes match"}
    F -- no --> G["Raise: pos trailing shape must match ids"]
    F -- yes --> H{"qkv_layout.is_thd()"}
    H -- no --> I["_segment_ids_to_seqlens BSHD path"]
    H -- yes --> J{"ids.ndim &gt; pos.ndim"}
    J -- no --> K["_segment_ids_pos_to_seqlens_offsets direct call"]
    J -- yes --> L["Compute n_extra_batch_dims for q and kv"]
    L --> M["Flatten extra batch dims via reshape: q_flat / kv_flat"]
    M --> N["jax.vmap(_segment_ids_pos_to_seqlens_offsets) in_axes=(0,0,None,None)"]
    N --> O["Reshape outputs back to original extra_batch_shape"]
    O --> P["Return seqlens / offsets"]
    K --> P
    I --> P

Comments Outside Diff (2)

  1. transformer_engine/jax/attention.py, line 739-747 (link)

    jnp.prod on empty shape tuple produces a float, breaking reshape

    When only one of q/kv has extra batch dims (asymmetric vmap), the side without extra dims ends up with extra_batch_shape = (). Calling jnp.prod(()) follows NumPy semantics and returns 1.0 (a float scalar), causing a TypeError in the subsequent reshape call:

    TypeError: Shapes must be 1D sequences of concrete values of integer type, got (1.0, ...)

    For example, if n_extra_batch_dims_q == 0 (q is not vmapped but kv is), then extra_flat_batch_size_q = jnp.prod(()) = 1.0 and q_segment_ids.reshape(1.0, ...) will fail.

    Use math.prod (Python 3.8+) to keep shape arithmetic in the integer domain:
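The scraped page dropped the suggested snippet; sketched with the variable names from the PR (the shapes here are illustrative assumptions for the asymmetric case), it would look like:

```python
import math

# Illustrative asymmetric case: q not vmapped, kv vmapped with batch 4
extra_batch_shape_q = ()
extra_batch_shape_kv = (4,)

# math.prod keeps shape arithmetic in the integer domain:
# math.prod(()) == 1 (int), unlike jnp.prod(()) == 1.0 (float)
extra_flat_batch_size_q = math.prod(extra_batch_shape_q)    # 1
extra_flat_batch_size_kv = math.prod(extra_batch_shape_kv)  # 4
```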

    (The import math should be moved to the top of the file alongside the other imports.)

  2. transformer_engine/jax/attention.py, line 731-741 (link)

    Missing validation for asymmetric extra batch shapes

    The guard condition at line 731 is satisfied when either q or kv has extra leading dims. If their extra batch shapes differ (e.g. extra_batch_shape_q = (4,) vs extra_batch_shape_kv = (2, 2)), the code will:

    1. Produce extra_flat_batch_size_q = 4 and extra_flat_batch_size_kv = 4 (equal numerically), so the jax.vmap call won't raise…
    2. …but the subsequent reshape at lines 760-763 reconstructs q_seqlens with shape (4, ...) and kv_seqlens with shape (2, 2, ...), leaving downstream code with inconsistently-shaped batch outputs.

    If the batch sizes also differ numerically, the jax.vmap call at line 756 will raise an opaque JAX shape-mismatch error. An explicit pre-check immediately below the computation of extra_batch_shape_* would make this far easier to debug:

    if extra_batch_shape_q != extra_batch_shape_kv:
        raise AssertionError(
            "q and kv segment_ids must have the same extra leading batch shape; got"
            f" extra_batch_shape_q={extra_batch_shape_q},"
            f" extra_batch_shape_kv={extra_batch_shape_kv}"
        )

Last reviewed commit: 0129045

Contributor

@greptile-apps greptile-apps bot left a comment


2 files reviewed, 3 comments


Collaborator

@jberchtold-nvidia jberchtold-nvidia left a comment


LGTM, thanks!

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment


1 file reviewed, 1 comment


Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment


1 file reviewed, 2 comments


for _ in range(leading_bdim):
    expanded = lax.expand_dims(expanded, (0,))
batched_args_list[seg_pos_idx] = jnp.broadcast_to(expanded, target_shape)
updated_batch_dims[seg_pos_idx] = 0
Contributor


consider using seg_id_bdim instead of hardcoding 0 for consistency, even though check_valid_batch_dims ensures it's always 0

Suggested change
updated_batch_dims[seg_pos_idx] = 0
updated_batch_dims[seg_pos_idx] = seg_id_bdim

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

for _ in range(leading_bdim):
    expanded = lax.expand_dims(expanded, (0,))
batched_args_list[seg_pos_idx] = jnp.broadcast_to(expanded, target_shape)
updated_batch_dims[seg_pos_idx] = 0
Contributor


consider using seg_id_bdim instead of hardcoding 0 for consistency, even though check_valid_batch_dims ensures it's always 0

Suggested change
updated_batch_dims[seg_pos_idx] = 0
updated_batch_dims[seg_pos_idx] = seg_id_bdim

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

…rts.

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
… the seqlens and offsets for fused attn

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Comment on lines +739 to +741
# assert flat_batch_q == flat_batch_kv, (
# f"segment_ids batch size mismatch: {batch_shape_q} vs {batch_shape_kv}"
# )
Contributor


commented assertion could lead to unclear error if q and kv have mismatched batch sizes. vmap would fail but with a generic JAX error. consider uncommenting or adding a comment explaining why validation isn't needed

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

…ed to get_seqlens_and_offsets()

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani force-pushed the klakhani/fix/vmap-get-seg-ids-pos branch from a7c398c to 395ac54 on February 27, 2026 18:36
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani force-pushed the klakhani/fix/vmap-get-seg-ids-pos branch from 386a633 to 693ba65 on February 27, 2026 19:19
@KshitijLakhani KshitijLakhani changed the title [JAX] Fix batcher in FusedAttn primitive for when seg ids bdims != seg pos bdims [JAX] Fix get_seqlens_and_offsets() to accept vmapped seg ids and non vmapped seg offsets Feb 27, 2026
@KshitijLakhani
Collaborator Author

/te-ci jax L0 L1 L2

@KshitijLakhani
Collaborator Author

CI passes. The only failure is due to HF requests for the A100 L2 test; rerunning passes it.

extra_flat_batch_size_kv, *kv_segment_ids.shape[n_extra_batch_dims_kv:]
)

def single_extra_batch(seg_id_q, seg_id_kv, seg_pos_q, seg_pos_kv):
Collaborator


Self-Resolve: Is single_extra_batch only for capturing the additional static args? If so, what do you think of

single_extra_batch = functools.partial(_segment_ids_pos_to_seqlens_offsets, attn_mask_type=attn_mask_type, window_size=window_size, max_segments_per_seq=max_segments_per_seq)

to make this more concise and not require nested scopes?

Collaborator Author


Makes sense. Thanks for catching
Made the change here: 0129045

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Collaborator

@jberchtold-nvidia jberchtold-nvidia left a comment


LGTM, thanks!


Successfully merging this pull request may close these issues.

JAX vmap issue with TE Attention
