[Performance] LSTM/GRU scan: canonical strides + cuDNN flat-storage clones + thread-local recurrent mode #3754

Merged
vmoens merged 4 commits into
gh/vmoens/270/base from
gh/vmoens/270/head
May 18, 2026

@vmoens vmoens commented May 15, 2026

Stack from ghstack (oldest at bottom):

Three intertwined fixes to the scan / triton recurrent backends.

  • Canonical-stride check. A [1, 4, 5] tensor with strides (5, 5, 1)
    passes is_contiguous() because the stride of a size-1 dimension is
    ignored, but torch._higher_order_ops.scan and the Triton kernels read
    strides directly and reject non-canonical layouts. Add
    _canonical_stride + _canonical_contiguous and re-materialize inputs /
    hidden buffers when strides drift off the C-canonical layout.

  • cuDNN flat-storage aliasing. With cuDNN, nn.LSTM / nn.GRU flatten
    all per-layer parameters into a single storage; the scan HOP tracer
    walks the FakeTensor graph and rejects the aliased per-layer views as
    inputs. Clone the weight views before closing the scan body. The
    per-layer carry now also clones x_t and the transpose+flatten output
    (the only remaining aliasing edge), so the existing .clone() on the
    full torch.stack(...) carry can be dropped.

  • Thread-local recurrent_mode. _ContextManager was a single mutable
    module-level flag, so a newly spawned collector worker thread saw the
    parent's recurrent_mode setting. Wrap it in
    _RecurrentModeContextManager, backed by contextvars.ContextVar, so
    that per-thread state is isolated.
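
The gap between "contiguous" and "canonical" in the first bullet can be illustrated without PyTorch. The following is a minimal sketch, assuming "C-canonical" means the row-major strides computed from the shape alone. The helper names are hypothetical stand-ins and are not the PR's `_canonical_stride` / `_canonical_contiguous`.

```python
from typing import Sequence, Tuple


def canonical_stride_sketch(shape: Sequence[int]) -> Tuple[int, ...]:
    """Row-major (C-order) strides derived from the shape alone."""
    strides = []
    running = 1
    for size in reversed(shape):
        strides.append(running)
        running *= max(size, 1)
    return tuple(reversed(strides))


def is_contiguous_relaxed(shape: Sequence[int], stride: Sequence[int]) -> bool:
    """Contiguity test that skips size-1 dims, whose stride never affects addressing."""
    expected = 1
    for size, st in zip(reversed(shape), reversed(stride)):
        if size == 1:
            continue
        if st != expected:
            return False
        expected *= size
    return True


def is_canonical(shape: Sequence[int], stride: Sequence[int]) -> bool:
    """Strict test: every stride, including those of size-1 dims, must match."""
    return tuple(stride) == canonical_stride_sketch(shape)


shape, stride = (1, 4, 5), (5, 5, 1)
print(is_contiguous_relaxed(shape, stride))  # True
print(is_canonical(shape, stride))           # False
print(canonical_stride_sketch(shape))        # (20, 5, 1)
```

A consumer that reads strides directly would see (5, 5, 1) rather than (20, 5, 1), which is the kind of mismatch the re-materialization step is meant to remove.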

Tests cover: _canonical_contiguous no-op vs re-materialize, scan/triton
output parity under non-canonical strides, scan-under-torch.compile
aliasing pin for both LSTM and GRU, LSTM scan/pad PPO-advantage parity,
non-canonical hidden-buffer regression for scan, and a thread-local
set_recurrent_mode test.
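
A thread-isolation check of the kind listed above could look roughly like the sketch below. It uses only the standard library; `_recurrent_mode` and `set_recurrent_mode_sketch` are hypothetical names for illustration and do not reproduce the PR's `_RecurrentModeContextManager`.

```python
import contextvars
import threading
from contextlib import contextmanager

# Hypothetical stand-in for the library's recurrent-mode flag.
_recurrent_mode = contextvars.ContextVar("recurrent_mode", default=False)


@contextmanager
def set_recurrent_mode_sketch(mode: bool):
    """Set the flag for the current context and restore it on exit."""
    token = _recurrent_mode.set(mode)
    try:
        yield
    finally:
        _recurrent_mode.reset(token)


observed = {}


def worker():
    # What the worker sees on entry depends on whether the interpreter starts
    # threads with an empty context or a copy of the caller's context.
    observed["worker_initial"] = _recurrent_mode.get()
    # This write lands in the worker's own context only.
    _recurrent_mode.set(False)


with set_recurrent_mode_sketch(True):
    t = threading.Thread(target=worker)
    t.start()
    t.join()
    observed["parent_inside"] = _recurrent_mode.get()

observed["parent_after"] = _recurrent_mode.get()
print(observed)
```

On interpreters where a new thread starts with an empty context, `worker_initial` is the default `False`. In every case the worker's write does not leak back: the parent still reads `True` inside the block and `False` after it.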

pytorch-bot Bot commented May 15, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3754

❌ 3 New Failures, 1 Cancelled Job

As of commit 244f61c with merge base 0a01ee8 (image):

NEW FAILURES - The following jobs have failed:

CANCELLED JOB - The following job was cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Comment on lines +1137 to +1138
hidden_per_step = _canonical_contiguous(hidden0_in[..., layer, :])
cell_per_step = _canonical_contiguous(hidden1_in[..., layer, :])
What if we applied _canonical_contiguous to hidden0_in once, before the loop? Would the per-layer view still be contiguous?
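
One way to reason about this question is to track strides by hand. The sketch below assumes, purely for illustration, a hidden buffer laid out as (batch, time, num_layers, hidden); the function names are hypothetical. It models `x[..., layer, :]` as dropping the second-to-last dimension while keeping every other stride.

```python
from typing import Tuple


def canonical_stride(shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Row-major strides derived from the shape alone."""
    strides = []
    running = 1
    for size in reversed(shape):
        strides.append(running)
        running *= max(size, 1)
    return tuple(reversed(strides))


def select_layer(shape: Tuple[int, ...], stride: Tuple[int, ...]):
    """Shape and stride of x[..., layer, :]: dim -2 disappears, the rest are kept."""
    return shape[:-2] + shape[-1:], stride[:-2] + stride[-1:]


def layer_view_is_canonical(shape: Tuple[int, ...]) -> bool:
    """Start from a canonical buffer and test the per-layer view's strides."""
    view_shape, view_stride = select_layer(shape, canonical_stride(shape))
    return view_stride == canonical_stride(view_shape)


# Assumed layout: (batch=2, time=3, num_layers=2, hidden=8).
shape = (2, 3, 2, 8)
print(canonical_stride(shape))                       # (48, 16, 8, 1)
print(select_layer(shape, canonical_stride(shape)))  # ((2, 3, 8), (48, 16, 1))
print(layer_view_is_canonical(shape))                # False
print(layer_view_is_canonical((2, 3, 1, 8)))         # True
```

Under these assumptions, canonicalizing the full buffer before the loop leaves the per-layer view canonical only when there is a single layer. With more layers the view keeps the parent's larger outer strides, so a per-slice call would still be needed unless the layer dimension came first.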

  b_hh = zeros if b_hh is None else b_hh

- hidden_per_step = hidden_in[..., layer, :]
+ hidden_per_step = _canonical_contiguous(hidden_in[..., layer, :])
Ditto: should we do it once before the loop?

vmoens added a commit that referenced this pull request May 18, 2026
…lones + thread-local recurrent mode

ghstack-source-id: 40fca02
Pull-Request: #3754
Co-authored-by: Cursor <cursoragent@cursor.com>
@vmoens vmoens merged commit 244f61c into gh/vmoens/270/base May 18, 2026
107 of 113 checks passed
@vmoens vmoens deleted the gh/vmoens/270/head branch May 18, 2026 21:01

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. Integrations/torch_geometric Integrations Modules Performance Performance issue or suggestion for improvement
