Is your feature request related to a problem? Please describe.
Hey guys, we have a custom transformer implementation with some parts wrapped in shard_map, but the use of with_sharding_constraint here
rng_state = with_sharding_constraint(rng_state, PartitionSpec(get_all_mesh_axes(), None))
makes it impossible to use fused_attn_fwd inside shard_map, and this is essentially the only line causing the problem. To work around it, I copy-pasted the source code of the following functions:
fused_attn_fwd
_fused_attn_fwd_rule
_fused_attn
fused_attn
while removing that particular line, which makes it possible to use cuDNN attention properly inside shard_map (a minimal sketch of the conflict is below).
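For context, here is a minimal, self-contained sketch of the kind of conflict described above: applying a global sharding constraint to an array inside a shard_map-wrapped function. This is not the TransformerEngine code; the mesh axis name "dp", the shapes, and the inner function are made up for illustration, and the exact failure mode depends on the JAX version.

```python
# Minimal sketch of the conflict: a global sharding constraint applied
# inside a shard_map region. All names and shapes here are illustrative.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), axis_names=("dp",))

def attn_like(x, rng_state):
    # Analogue of the line inside fused_attn_fwd: constrain rng_state to be
    # sharded over the mesh axes. Inside shard_map, x and rng_state are
    # per-device shards rather than global arrays, so this global constraint
    # conflicts with the manually sharded region (exact error depends on the
    # JAX version).
    rng_state = jax.lax.with_sharding_constraint(
        rng_state, NamedSharding(mesh, P("dp", None))
    )
    return x + rng_state.sum()

f = jax.jit(
    shard_map(
        attn_like,
        mesh=mesh,
        in_specs=(P("dp"), P("dp", None)),
        out_specs=P("dp"),
    )
)

x = jnp.zeros((8 * jax.device_count(), 16))
rng_state = jnp.zeros((2 * jax.device_count(), 4))
# f(x, rng_state)  # fails / is ill-defined because of the constraint above
```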
Describe the solution you'd like
I wonder if we could move this sharding constraint somewhere else, outside of the fused attention forward?
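As an illustration of what "moving it outside" could look like, here is a hypothetical sketch (not TransformerEngine's actual API; the names inner_fused_attn and model_step, the mesh axis "dp", and all shapes are made up) where the caller constrains rng_state once at the regular jit trace level, outside shard_map, and the attention forward itself no longer calls with_sharding_constraint:

```python
# Hypothetical sketch: the rng_state sharding constraint lives at the call
# site, outside the shard_map region, instead of inside the fused-attention
# forward. Names, shapes, and the inner function body are illustrative only.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), axis_names=("dp",))

def inner_fused_attn(x, rng_state):
    # Stands in for a fused attention forward without the internal
    # with_sharding_constraint, so it is safe to call under shard_map.
    return x + rng_state.sum()

@jax.jit
def model_step(x, rng_state):
    # The constraint is applied here, on the global array, where
    # with_sharding_constraint is well defined.
    rng_state = jax.lax.with_sharding_constraint(
        rng_state, NamedSharding(mesh, P("dp", None))
    )
    attn = shard_map(
        inner_fused_attn,
        mesh=mesh,
        in_specs=(P("dp"), P("dp", None)),
        out_specs=P("dp"),
    )
    return attn(x, rng_state)

x = jnp.zeros((8 * jax.device_count(), 16))
rng_state = jnp.zeros((2 * jax.device_count(), 4))
out = model_step(x, rng_state)
```

This would keep the sharding hint for the non-shard_map path while leaving the forward itself free of global constraints, so no copy-pasting of the fused attention functions would be needed.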