
[JAX] Sharding constraint for rng_state in fused_attn_fwd makes it impossible to use in shard_map #2500

@qGentry

Description


Is your feature request related to a problem? Please describe.

Hey guys, we have a custom transformer implementation with some parts wrapped in shard_map, but the with_sharding_constraint call here

rng_state = with_sharding_constraint(rng_state, PartitionSpec(get_all_mesh_axes(), None))

makes it impossible to use fused_attn_fwd inside shard_map. As far as I can tell, this is the only line that prevents it. As a workaround, I copied the source code of the following functions:

fused_attn_fwd
_fused_attn_fwd_rule
_fused_attn
fused_attn

and removed this particular line; with that change, cuDNN attention works properly inside shard_map.

Describe the solution you'd like

Could this sharding constraint be moved somewhere else, outside of fused_attn_fwd?
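A minimal sketch of what moving the constraint out could look like, in plain JAX. The names `attn_fwd_core`, `attn_fwd_global` and `attn_fwd_local`, the `"dp"` axis name and the `rng_state` shape are all hypothetical stand-ins, not Transformer Engine's API, and a naive softmax attention replaces the cuDNN kernel. The idea: the forward itself carries no sharding constraint, only the caller working on global arrays applies one, so the same core can also be called inside shard_map.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

try:  # newer JAX exposes shard_map at the top level
    from jax import shard_map
except ImportError:  # older JAX keeps it under jax.experimental
    from jax.experimental.shard_map import shard_map

# Hypothetical 1-D mesh over all local devices; "dp" is an illustrative axis name.
mesh = Mesh(np.array(jax.devices()), ("dp",))
n_dev = mesh.devices.size


def attn_fwd_core(q, k, v, rng_state):
    """Hypothetical constraint-free forward (naive attention stands in for cuDNN)."""
    scores = jnp.einsum("bqd,bkd->bqk", q, k) / jnp.sqrt(q.shape[-1])
    out = jnp.einsum("bqk,bkd->bqd", jax.nn.softmax(scores, axis=-1), v)
    return out, rng_state  # rng_state is only passed through in this sketch


@jax.jit
def attn_fwd_global(q, k, v, rng_state):
    # Caller-side constraint: valid here because rng_state is a global array,
    # mirroring PartitionSpec(get_all_mesh_axes(), None) from the issue.
    rng_state = jax.lax.with_sharding_constraint(
        rng_state, NamedSharding(mesh, P(mesh.axis_names, None)))
    return attn_fwd_core(q, k, v, rng_state)


# Inside shard_map each argument is a per-device shard, so the core is
# called without any with_sharding_constraint.
attn_fwd_local = jax.jit(shard_map(
    attn_fwd_core,
    mesh=mesh,
    in_specs=(P("dp"), P("dp"), P("dp"), P("dp", None)),
    out_specs=(P("dp"), P("dp", None)),
))

q = k = v = jnp.ones((2 * n_dev, 4, 8), jnp.float32)
rng_state = jnp.zeros((n_dev, 2), jnp.uint32)  # hypothetical shape

out_global, _ = attn_fwd_global(q, k, v, rng_state)
out_local, _ = attn_fwd_local(q, k, v, rng_state)
```

Both paths give the same output, and only the global path touches sharding constraints. Whether the real rng_state constraint can safely move to the caller depends on how the fused kernel consumes the RNG state, which this sketch does not model.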
