Skip to content

[BugFix] PolicyVersion: int64 dtype + preserve tensordict device#3755

Merged
vmoens merged 4 commits into
gh/vmoens/271/basefrom
gh/vmoens/271/head
May 18, 2026
Merged

[BugFix] PolicyVersion: int64 dtype + preserve tensordict device#3755
vmoens merged 4 commits into
gh/vmoens/271/basefrom
gh/vmoens/271/head

Conversation

@vmoens
Copy link
Copy Markdown
Collaborator

@vmoens vmoens commented May 15, 2026

Stack from ghstack (oldest at bottom):

The integer branch was casting the version through float() and
calling torch.full(shape, value) without a device, so the produced
tensor was a CPU float tensor regardless of the surrounding tensordict.
This broke any path that compared the version against an integer dtype
or that lived on CUDA.

Read the tensordict's device and pass dtype=torch.int64 (with
device only when the tensordict has one, since CPU tensordicts can
have device=None).

Add a regression test pinning the dtype, shape, and device of the
emitted policy_version tensor.

[ghstack-poisoned]
@pytorch-bot
Copy link
Copy Markdown

pytorch-bot Bot commented May 15, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3755

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

❌ 3 New Failures, 3 Unrelated Failures

As of commit 938b896 with merge base 0a01ee8 (image):

NEW FAILURES - The following jobs have failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
vmoens added a commit that referenced this pull request May 18, 2026
The integer branch was casting the version through ``float()`` and
calling ``torch.full(shape, value)`` without a device, so the produced
tensor was a CPU float tensor regardless of the surrounding tensordict.
This broke any path that compared the version against an integer dtype
or that lived on CUDA.

Read the tensordict's device and pass ``dtype=torch.int64`` (with
``device`` only when the tensordict has one, since CPU tensordicts can
have ``device=None``).

Add a regression test pinning the dtype, shape, and device of the
emitted ``policy_version`` tensor.

ghstack-source-id: 66bea0a
Pull-Request: #3755
@vmoens vmoens merged commit 938b896 into gh/vmoens/271/base May 18, 2026
105 of 113 checks passed
@vmoens vmoens deleted the gh/vmoens/271/head branch May 18, 2026 21:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

BugFix CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. llm/ LLM-related PR, triggers LLM CI tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant