Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ strands-agents/
│ │ ├── event_loop.py # Event loop types
│ │ ├── json_dict.py # JSON dict utilities
│ │ ├── collections.py # Collection types
│ │ ├── _snapshot.py # Snapshot types and helpers
│ │ ├── _events.py # Internal event types
│ │ ├── a2a.py # A2A protocol types
│ │ └── models/ # Model-specific types
Expand Down
2 changes: 2 additions & 0 deletions src/strands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .event_loop._retry import ModelRetryStrategy
from .plugins import Plugin
from .tools.decorator import tool
from .types._snapshot import Snapshot
from .types.tools import ToolContext
from .vended_plugins.skills import AgentSkills, Skill

Expand All @@ -18,6 +19,7 @@
"ModelRetryStrategy",
"Plugin",
"Skill",
"Snapshot",
"tool",
"ToolContext",
"types",
Expand Down
90 changes: 90 additions & 0 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
2. Method-style for direct tool access: `agent.tool.tool_name(param1="value")`
"""

import copy
import logging
import threading
import warnings
Expand All @@ -29,6 +30,15 @@
from ..event_loop._retry import ModelRetryStrategy
from ..event_loop.event_loop import INITIAL_DELAY, MAX_ATTEMPTS, MAX_DELAY, event_loop_cycle
from ..tools._tool_helpers import generate_missing_tool_result_content
from ..types._snapshot import (
SNAPSHOT_SCHEMA_VERSION,
Snapshot,
SnapshotField,
SnapshotPreset,
TakeSnapshotOptions,
_utc_now_iso,
resolve_snapshot_fields,
)

if TYPE_CHECKING:
from ..tools import ToolProvider
Expand Down Expand Up @@ -1039,6 +1049,86 @@ async def _append_messages(self, *messages: Message) -> None:
self.messages.append(message)
await self.hooks.invoke_callbacks_async(MessageAddedEvent(agent=self, message=message))

def take_snapshot(
self,
*,
preset: SnapshotPreset | None = None,
include: list[SnapshotField] | None = None,
exclude: list[SnapshotField] | None = None,
app_data: dict[str, Any] | None = None,
) -> Snapshot:
"""Capture current agent state as an in-memory snapshot.

Args:
preset: Named preset of fields to capture. Currently only "session" is supported,
which captures messages, state, conversation_manager_state, and interrupt_state.
include: Additional fields to capture on top of the preset.
exclude: Fields to remove after applying preset and include.
app_data: Application-owned arbitrary JSON stored verbatim in the snapshot.

Returns:
A Snapshot containing the captured agent state.

Raises:
SnapshotException: If no fields are resolved or an invalid field name is provided.
"""
options: TakeSnapshotOptions = {}
if preset is not None:
options["preset"] = preset
if include is not None:
options["include"] = include
if exclude is not None:
options["exclude"] = exclude

fields = resolve_snapshot_fields(options)

data: dict[str, Any] = {}
if "messages" in fields:
data["messages"] = copy.deepcopy(self.messages)
if "state" in fields:
data["state"] = self.state.get()
if "conversation_manager_state" in fields:
data["conversation_manager_state"] = self.conversation_manager.get_state()
if "interrupt_state" in fields:
data["interrupt_state"] = self._interrupt_state.to_dict()
if "system_prompt" in fields:
# Store the content-block representation so round-trips preserve caching hints and
# other block-level metadata.
data["system_prompt"] = self._system_prompt_content

return Snapshot(
schema_version=SNAPSHOT_SCHEMA_VERSION,
created_at=_utc_now_iso(),
data=data,
app_data=copy.deepcopy(app_data) if app_data else {},
)

def load_snapshot(self, snapshot: Snapshot) -> None:
"""Restore agent state from a previously captured snapshot.

Only fields present in snapshot.data are restored; absent fields are left unchanged.

Args:
snapshot: The snapshot to restore from.

Raises:
SnapshotException: If snapshot.schema_version is not "1.0".
"""
snapshot.validate()

data = snapshot.data

if "messages" in data:
self.messages = copy.deepcopy(data["messages"])
if "state" in data:
self.state = AgentState(data["state"])
if "conversation_manager_state" in data:
self.conversation_manager.restore_from_session(data["conversation_manager_state"])
if "interrupt_state" in data:
self._interrupt_state = _InterruptState.from_dict(data["interrupt_state"])
if "system_prompt" in data:
self.system_prompt = copy.deepcopy(data["system_prompt"])

def _redact_user_content(self, content: list[ContentBlock], redact_message: str) -> list[ContentBlock]:
"""Redact user content preserving toolResult blocks.

Expand Down
3 changes: 2 additions & 1 deletion src/strands/types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""SDK type definitions."""

from ._snapshot import Snapshot
from .collections import PaginatedList

__all__ = ["PaginatedList"]
__all__ = ["PaginatedList", "Snapshot"]
131 changes: 131 additions & 0 deletions src/strands/types/_snapshot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
"""Snapshot types, constants, and helpers for agent state capture."""

from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Literal, TypedDict

from .exceptions import SnapshotException

SnapshotField = Literal["messages", "state", "conversation_manager_state", "interrupt_state", "system_prompt"]
SnapshotPreset = Literal["session"]

ALL_SNAPSHOT_FIELDS: tuple[SnapshotField, ...] = (
"messages",
"state",
"conversation_manager_state",
"interrupt_state",
"system_prompt",
)

SNAPSHOT_SCHEMA_VERSION = "1.0"

SNAPSHOT_PRESETS: dict[str, tuple[SnapshotField, ...]] = {
"session": ("messages", "state", "conversation_manager_state", "interrupt_state"),
}


class TakeSnapshotOptions(TypedDict, total=False):
"""Internal options for take_snapshot. Not exported publicly."""

preset: SnapshotPreset
include: list[SnapshotField]
exclude: list[SnapshotField]
app_data: dict[str, Any]


@dataclass
class Snapshot:
"""Point-in-time capture of agent state as a versioned JSON-compatible object."""

schema_version: str
created_at: str # ISO 8601 UTC
data: dict[str, Any]
app_data: dict[str, Any]

def validate(self) -> None:
"""Validate that this snapshot can be loaded by the current SDK version.

Raises:
SnapshotException: If schema_version is not "1.0".
"""
if self.schema_version != SNAPSHOT_SCHEMA_VERSION:
raise SnapshotException(
f"Unsupported snapshot schema version: {self.schema_version!r}. "
f"Current version: {SNAPSHOT_SCHEMA_VERSION}"
)

def to_dict(self) -> dict[str, Any]:
"""Serialize to a plain JSON-compatible dict."""
return {
"schema_version": self.schema_version,
"created_at": self.created_at,
"data": self.data,
"app_data": self.app_data,
}

@classmethod
def from_dict(cls, d: dict[str, Any]) -> Snapshot:
"""Reconstruct a Snapshot from a dict produced by to_dict().

Raises:
SnapshotException: If schema_version is not "1.0".
"""
snapshot = cls(
schema_version=d.get("schema_version", ""),
created_at=d["created_at"],
data=d["data"],
app_data=d.get("app_data", {}),
)
snapshot.validate()
return snapshot


def resolve_snapshot_fields(options: TakeSnapshotOptions) -> set[SnapshotField]:
"""Resolve the set of fields to capture based on options.

Applies: preset → include → exclude (in that order).

Raises:
SnapshotException: If any field name is invalid or the resolved set is empty.
"""
valid = set(ALL_SNAPSHOT_FIELDS)

# Validate include/exclude field names
for field in options.get("include") or []:
if field not in valid:
raise SnapshotException(f"Invalid snapshot field: {field!r}. Valid fields: {sorted(valid)}")
for field in options.get("exclude") or []:
if field not in valid:
raise SnapshotException(f"Invalid snapshot field: {field!r}. Valid fields: {sorted(valid)}")

# Step 1: start with preset
preset = options.get("preset")
if preset is not None:
fields: set[SnapshotField] = set(SNAPSHOT_PRESETS[preset])
else:
fields = set()

# Step 2: union with include
include = options.get("include")
if include:
fields |= set(include)

# Step 3: subtract exclude
exclude = options.get("exclude")
if exclude:
fields -= set(exclude)

if not fields:
raise SnapshotException(
"No snapshot fields resolved. Provide a preset or at least one field in 'include'. "
"Note: passing only 'exclude' without a preset or 'include' always results in an empty set."
)

return fields


def _utc_now_iso() -> str:
"""Return the current UTC time as an ISO 8601 string ending in 'Z'."""
return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
6 changes: 6 additions & 0 deletions src/strands/types/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ class SessionException(Exception):
pass


class SnapshotException(Exception):
"""Exception raised when snapshot operations fail (e.g., unsupported schema version)."""

pass


class ToolProviderException(Exception):
"""Exception raised when a tool provider fails to load or cleanup tools."""

Expand Down
Loading
Loading