[Feature] WorldModel + WorldModelLoss — general model-based RL abstraction#3783

Open
theap06 wants to merge 1 commit into pytorch:main from theap06:feat/World_Model

Conversation

@theap06
Contributor

@theap06 theap06 commented May 20, 2026

Fixes #3774

  • Adds `WorldModel(TensorDictModuleBase)`: a key-driven, architecture-agnostic composition layer for encoder + dynamics + reward/done/decoder heads, with `encode`, `step`, and `decode` shortcuts and a `rollout(start_td, policy, horizon)` method. The `[batch, horizon]` rollout output matches `EnvBase.rollout`, so imagined trajectories are drop-in compatible with replay buffers, GAE, and loss modules.
  • Adds `WorldModelLoss(LossModule)`: follows the standard `_AcceptedKeys` / `set_keys()` / `forward() → TensorDict` pattern; supports configurable sub-losses (`reward`, `done`, `reconstruction`, `latent`) with per-loss weights and distance-function choice.
  • Existing Dreamer components (`WorldModelWrapper`, `DreamerEnv`, `RSSMRollout`, `DreamerModelLoss/ActorLoss/ValueLoss`) are unchanged.
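
To make the rollout contract above concrete, here is a minimal sketch in plain Python. It does not use TorchRL or the classes added in this PR: `imagine`, `weighted_loss`, and the toy lambda components are hypothetical stand-ins chosen for illustration, and plain lists and dicts stand in for tensors and TensorDicts. It only shows the idea of stepping an encoder, dynamics model, and reward head under a policy to get a `[batch, horizon]` structure, and of combining named sub-losses with per-loss weights.

```python
from typing import Callable, Dict, List, Sequence

Transition = Dict[str, float]


def imagine(
    start_obs: Sequence[float],
    encode: Callable[[float], float],
    dynamics: Callable[[float, float], float],
    reward_head: Callable[[float], float],
    policy: Callable[[float], float],
    horizon: int,
) -> List[List[Transition]]:
    """Roll each start observation forward `horizon` steps in latent space.

    Hypothetical helper (not the PR's API): returns a nested list indexed
    as [batch][time], mirroring a [batch, horizon] layout.
    """
    trajectories = []
    for obs in start_obs:  # batch dimension
        latent = encode(obs)
        steps = []
        for _ in range(horizon):  # time dimension
            action = policy(latent)
            next_latent = dynamics(latent, action)
            steps.append(
                {
                    "latent": latent,
                    "action": action,
                    "next_latent": next_latent,
                    "reward": reward_head(next_latent),
                }
            )
            latent = next_latent
        trajectories.append(steps)
    return trajectories


def weighted_loss(sub_losses: Dict[str, float], weights: Dict[str, float]) -> float:
    """Sum named sub-losses; a sub-loss without an explicit weight uses 1.0."""
    return sum(weights.get(name, 1.0) * value for name, value in sub_losses.items())


# Toy components (assumptions): identity encoder, additive dynamics,
# reward equal to the negative distance of the latent from zero.
traj = imagine(
    start_obs=[0.0, 1.0, 2.0],
    encode=lambda o: o,
    dynamics=lambda z, a: z + a,
    reward_head=lambda z: -abs(z),
    policy=lambda z: -0.5 * z,
    horizon=4,
)
batch, horizon = len(traj), len(traj[0])

total = weighted_loss(
    {"reward": 0.2, "done": 0.1, "latent": 0.4},
    {"reward": 1.0, "latent": 0.5},
)
```

In this toy, every sample is stepped for the full horizon; how the actual `rollout` treats early termination is defined by the PR's implementation and tests, not by this sketch.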

Motivation

TorchRL ships Dreamer-specific model-based components but no general abstraction. Users implementing MBPO, TD-MPC, PlaNet, or any custom world model must hand-wire `TensorDictModule` chains and write bespoke multi-step rollout loops. This adds the missing general layer.

Test plan

  • `python -m pytest test/test_world_model.py -v` — 19 new tests covering forward, encode/step/decode shortcuts, nested keys, rollout shape and early termination, replay buffer compatibility, all four loss types, `set_keys`, per-loss weights, and gradient flow
  • `python -c "from torchrl.modules import WorldModel; from torchrl.objectives import WorldModelLoss; print('OK')"` — import smoke test
  • Existing dreamer-related tests (`pytest test/ -k "dreamer" --ignore=test/llm`) — 409 tests, all passing

@pytorch-bot

pytorch-bot Bot commented May 20, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3783

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 2 New Failures

As of commit ddb8e36 with merge base eb90c5d (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed label May 20, 2026
@github-actions
Contributor

⚠️ PR Title Label Error

PR title must start with a label prefix in brackets (e.g., [BugFix]).

Current title: WorldModel + WorldModelLoss — general model-based RL abstraction

Supported Prefixes (case-sensitive)

Your PR title must start with exactly one of these prefixes:

| Prefix | Label Applied | Example |
|---|---|---|
| [BugFix] | BugFix | [BugFix] Fix memory leak in collector |
| [Feature] | Feature | [Feature] Add new optimizer |
| [Doc] or [Docs] | Documentation | [Doc] Update installation guide |
| [Refactor] | Refactoring | [Refactor] Clean up module imports |
| [CI] | CI | [CI] Fix workflow permissions |
| [Test] or [Tests] | Tests | [Tests] Add unit tests for buffer |
| [Environment] or [Environments] | Environments | [Environments] Add Gymnasium support |
| [Data] | Data | [Data] Fix replay buffer sampling |
| [Performance] or [Perf] | Performance | [Performance] Optimize tensor ops |
| [BC-Breaking] | bc breaking | [BC-Breaking] Remove deprecated API |
| [Deprecation] | Deprecation | [Deprecation] Mark old function |
| [Quality] | Quality | [Quality] Fix typos and add codespell |

Note: Common variations like singular/plural are supported (e.g., [Doc] or [Docs]).


@theap06 theap06 changed the title from "WorldModel + WorldModelLoss — general model-based RL abstraction" to "[Feature] WorldModel + WorldModelLoss — general model-based RL abstraction" May 20, 2026
@github-actions github-actions Bot added the Feature label May 20, 2026
@theap06
Contributor Author

theap06 commented May 21, 2026

@vmoens @elin-bdai this was my attempt at creating a World Model abstraction. Lmk if you have any feedback!

@theap06
Contributor Author

theap06 commented May 22, 2026

@vmoens I believe this issue is tied to the flaky tests from earlier. Lmk if the design seems sound to you.

Labels

  • CLA Signed: This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
  • Documentation: Improvements or additions to documentation
  • Feature: New feature
  • Integrations/torch_geometric
  • Integrations
  • Modules
  • Objectives

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[RFC] torchrl.modules.WorldModel — A General TensorDict-Native World Model Abstraction

1 participant