diff --git a/AGENTS.md b/AGENTS.md index a9a2a5044..3615e713a 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -105,6 +105,7 @@ strands-agents/ │ │ ├── event_loop.py # Event loop types │ │ ├── json_dict.py # JSON dict utilities │ │ ├── collections.py # Collection types +│ │ ├── _snapshot.py # Snapshot types and helpers │ │ ├── _events.py # Internal event types │ │ ├── a2a.py # A2A protocol types │ │ └── models/ # Model-specific types diff --git a/src/strands/__init__.py b/src/strands/__init__.py index 2078f16ce..6625ac41f 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -6,6 +6,7 @@ from .event_loop._retry import ModelRetryStrategy from .plugins import Plugin from .tools.decorator import tool +from .types._snapshot import Snapshot from .types.tools import ToolContext from .vended_plugins.skills import AgentSkills, Skill @@ -18,6 +19,7 @@ "ModelRetryStrategy", "Plugin", "Skill", + "Snapshot", "tool", "ToolContext", "types", diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index f378a886a..abbbaf648 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -9,6 +9,7 @@ 2. Method-style for direct tool access: `agent.tool.tool_name(param1="value")` """ +import copy import logging import threading import warnings @@ -29,6 +30,15 @@ from ..event_loop._retry import ModelRetryStrategy from ..event_loop.event_loop import INITIAL_DELAY, MAX_ATTEMPTS, MAX_DELAY, event_loop_cycle from ..tools._tool_helpers import generate_missing_tool_result_content +from ..types._snapshot import ( + SNAPSHOT_SCHEMA_VERSION, + Snapshot, + SnapshotField, + SnapshotPreset, + TakeSnapshotOptions, + _utc_now_iso, + resolve_snapshot_fields, +) if TYPE_CHECKING: from ..tools import ToolProvider @@ -1039,6 +1049,86 @@ async def _append_messages(self, *messages: Message) -> None: self.messages.append(message) await self.hooks.invoke_callbacks_async(MessageAddedEvent(agent=self, message=message)) + def take_snapshot( + self, + *, + preset: SnapshotPreset | None = None, + include: list[SnapshotField] | None = None, + exclude: list[SnapshotField] | None = None, + app_data: dict[str, Any] | None = None, + ) -> Snapshot: + """Capture current agent state as an in-memory snapshot. + + Args: + preset: Named preset of fields to capture. Currently only "session" is supported, + which captures messages, state, conversation_manager_state, and interrupt_state. + include: Additional fields to capture on top of the preset. + exclude: Fields to remove after applying preset and include. + app_data: Application-owned arbitrary JSON stored verbatim in the snapshot. + + Returns: + A Snapshot containing the captured agent state. + + Raises: + SnapshotException: If no fields are resolved or an invalid field name is provided. + """ + options: TakeSnapshotOptions = {} + if preset is not None: + options["preset"] = preset + if include is not None: + options["include"] = include + if exclude is not None: + options["exclude"] = exclude + + fields = resolve_snapshot_fields(options) + + data: dict[str, Any] = {} + if "messages" in fields: + data["messages"] = copy.deepcopy(self.messages) + if "state" in fields: + data["state"] = self.state.get() + if "conversation_manager_state" in fields: + data["conversation_manager_state"] = self.conversation_manager.get_state() + if "interrupt_state" in fields: + data["interrupt_state"] = self._interrupt_state.to_dict() + if "system_prompt" in fields: + # Store the content-block representation so round-trips preserve caching hints and + # other block-level metadata. + data["system_prompt"] = self._system_prompt_content + + return Snapshot( + schema_version=SNAPSHOT_SCHEMA_VERSION, + created_at=_utc_now_iso(), + data=data, + app_data=copy.deepcopy(app_data) if app_data else {}, + ) + + def load_snapshot(self, snapshot: Snapshot) -> None: + """Restore agent state from a previously captured snapshot. + + Only fields present in snapshot.data are restored; absent fields are left unchanged. + + Args: + snapshot: The snapshot to restore from. + + Raises: + SnapshotException: If snapshot.schema_version is not "1.0". + """ + snapshot.validate() + + data = snapshot.data + + if "messages" in data: + self.messages = copy.deepcopy(data["messages"]) + if "state" in data: + self.state = AgentState(data["state"]) + if "conversation_manager_state" in data: + self.conversation_manager.restore_from_session(data["conversation_manager_state"]) + if "interrupt_state" in data: + self._interrupt_state = _InterruptState.from_dict(data["interrupt_state"]) + if "system_prompt" in data: + self.system_prompt = copy.deepcopy(data["system_prompt"]) + def _redact_user_content(self, content: list[ContentBlock], redact_message: str) -> list[ContentBlock]: """Redact user content preserving toolResult blocks. diff --git a/src/strands/types/__init__.py b/src/strands/types/__init__.py index 7eef60cb4..60d6b3a17 100644 --- a/src/strands/types/__init__.py +++ b/src/strands/types/__init__.py @@ -1,5 +1,6 @@ """SDK type definitions.""" +from ._snapshot import Snapshot from .collections import PaginatedList -__all__ = ["PaginatedList"] +__all__ = ["PaginatedList", "Snapshot"] diff --git a/src/strands/types/_snapshot.py b/src/strands/types/_snapshot.py new file mode 100644 index 000000000..fce0fd76e --- /dev/null +++ b/src/strands/types/_snapshot.py @@ -0,0 +1,131 @@ +"""Snapshot types, constants, and helpers for agent state capture.""" + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Any, Literal, TypedDict + +from .exceptions import SnapshotException + +SnapshotField = Literal["messages", "state", "conversation_manager_state", "interrupt_state", "system_prompt"] +SnapshotPreset = Literal["session"] + +ALL_SNAPSHOT_FIELDS: tuple[SnapshotField, ...] = ( + "messages", + "state", + "conversation_manager_state", + "interrupt_state", + "system_prompt", +) + +SNAPSHOT_SCHEMA_VERSION = "1.0" + +SNAPSHOT_PRESETS: dict[str, tuple[SnapshotField, ...]] = { + "session": ("messages", "state", "conversation_manager_state", "interrupt_state"), +} + + +class TakeSnapshotOptions(TypedDict, total=False): + """Internal options for take_snapshot. Not exported publicly.""" + + preset: SnapshotPreset + include: list[SnapshotField] + exclude: list[SnapshotField] + app_data: dict[str, Any] + + +@dataclass +class Snapshot: + """Point-in-time capture of agent state as a versioned JSON-compatible object.""" + + schema_version: str + created_at: str # ISO 8601 UTC + data: dict[str, Any] + app_data: dict[str, Any] + + def validate(self) -> None: + """Validate that this snapshot can be loaded by the current SDK version. + + Raises: + SnapshotException: If schema_version is not "1.0". + """ + if self.schema_version != SNAPSHOT_SCHEMA_VERSION: + raise SnapshotException( + f"Unsupported snapshot schema version: {self.schema_version!r}. " + f"Current version: {SNAPSHOT_SCHEMA_VERSION}" + ) + + def to_dict(self) -> dict[str, Any]: + """Serialize to a plain JSON-compatible dict.""" + return { + "schema_version": self.schema_version, + "created_at": self.created_at, + "data": self.data, + "app_data": self.app_data, + } + + @classmethod + def from_dict(cls, d: dict[str, Any]) -> Snapshot: + """Reconstruct a Snapshot from a dict produced by to_dict(). + + Raises: + SnapshotException: If schema_version is not "1.0". + """ + snapshot = cls( + schema_version=d.get("schema_version", ""), + created_at=d["created_at"], + data=d["data"], + app_data=d.get("app_data", {}), + ) + snapshot.validate() + return snapshot + + +def resolve_snapshot_fields(options: TakeSnapshotOptions) -> set[SnapshotField]: + """Resolve the set of fields to capture based on options. + + Applies: preset → include → exclude (in that order). + + Raises: + SnapshotException: If any field name is invalid or the resolved set is empty. + """ + valid = set(ALL_SNAPSHOT_FIELDS) + + # Validate include/exclude field names + for field in options.get("include") or []: + if field not in valid: + raise SnapshotException(f"Invalid snapshot field: {field!r}. Valid fields: {sorted(valid)}") + for field in options.get("exclude") or []: + if field not in valid: + raise SnapshotException(f"Invalid snapshot field: {field!r}. Valid fields: {sorted(valid)}") + + # Step 1: start with preset + preset = options.get("preset") + if preset is not None: + fields: set[SnapshotField] = set(SNAPSHOT_PRESETS[preset]) + else: + fields = set() + + # Step 2: union with include + include = options.get("include") + if include: + fields |= set(include) + + # Step 3: subtract exclude + exclude = options.get("exclude") + if exclude: + fields -= set(exclude) + + if not fields: + raise SnapshotException( + "No snapshot fields resolved. Provide a preset or at least one field in 'include'. " + "Note: passing only 'exclude' without a preset or 'include' always results in an empty set." + ) + + return fields + + +def _utc_now_iso() -> str: + """Return the current UTC time as an ISO 8601 string ending in 'Z'.""" + return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py index 1d1983abd..5db80a26e 100644 --- a/src/strands/types/exceptions.py +++ b/src/strands/types/exceptions.py @@ -77,6 +77,12 @@ class SessionException(Exception): pass +class SnapshotException(Exception): + """Exception raised when snapshot operations fail (e.g., unsupported schema version).""" + + pass + + class ToolProviderException(Exception): """Exception raised when a tool provider fails to load or cleanup tools.""" diff --git a/tests/strands/agent/test_snapshot.py b/tests/strands/agent/test_snapshot.py new file mode 100644 index 000000000..c80ccfcbc --- /dev/null +++ b/tests/strands/agent/test_snapshot.py @@ -0,0 +1,287 @@ +"""Tests for _snapshot.py — Snapshot dataclass and resolve_snapshot_fields.""" + +import re +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from strands import Agent +from strands.types._snapshot import ( + ALL_SNAPSHOT_FIELDS, + SNAPSHOT_PRESETS, + SNAPSHOT_SCHEMA_VERSION, + Snapshot, + TakeSnapshotOptions, + resolve_snapshot_fields, +) +from strands.types.exceptions import SnapshotException + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +ISO_8601_UTC_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(\.\d+)?Z$") + + +def _make_snapshot(**kwargs: object) -> Snapshot: + defaults: dict[str, Any] = { + "schema_version": SNAPSHOT_SCHEMA_VERSION, + "created_at": "2025-01-15T12:00:00.000000Z", + "data": {}, + "app_data": {}, + } + defaults.update(kwargs) + return Snapshot(**defaults) + + +def _make_agent(**kwargs) -> Agent: + """Create a minimal Agent with a mock model for testing.""" + mock_model = MagicMock() + mock_model.get_config.return_value = {} + return Agent(model=mock_model, callback_handler=None, **kwargs) + + +def test_snapshot_from_dict_bad_version_raises(): + d = {"schema_version": "99.0", "created_at": "2025-01-15T12:00:00Z", "data": {}, "app_data": {}} + with pytest.raises(SnapshotException, match="Unsupported snapshot schema version"): + Snapshot.from_dict(d) + + +def test_snapshot_to_dict_round_trip(): + s = _make_snapshot(data={"messages": []}, app_data={"x": 1}) + assert Snapshot.from_dict(s.to_dict()) == s + + +def test_resolve_snapshot_fields_invalid_include_raises(): + with pytest.raises(SnapshotException, match="Invalid snapshot field"): + resolve_snapshot_fields({"include": ["not_a_field"]}) # type: ignore[typeddict-item] + + +def test_resolve_snapshot_fields_invalid_exclude_raises(): + with pytest.raises(SnapshotException, match="Invalid snapshot field"): + resolve_snapshot_fields({"preset": "session", "exclude": ["not_a_field"]}) # type: ignore[typeddict-item] + + +def test_resolve_snapshot_fields_no_preset_no_include_raises(): + with pytest.raises(SnapshotException, match="No snapshot fields resolved"): + resolve_snapshot_fields({}) + + +def test_resolve_snapshot_fields_session_preset(): + assert resolve_snapshot_fields({"preset": "session"}) == set(SNAPSHOT_PRESETS["session"]) + + +def test_resolve_snapshot_fields_include_adds_to_preset(): + fields = resolve_snapshot_fields({"preset": "session", "include": ["system_prompt"]}) + assert fields == set(SNAPSHOT_PRESETS["session"]) | {"system_prompt"} + + +def test_resolve_snapshot_fields_exclude_removes_from_preset(): + fields = resolve_snapshot_fields({"preset": "session", "exclude": ["messages"]}) + assert "messages" not in fields + + +def test_resolve_snapshot_fields_all_excluded_raises(): + with pytest.raises(SnapshotException): + resolve_snapshot_fields({"exclude": list(ALL_SNAPSHOT_FIELDS)}) # type: ignore[typeddict-item] + + +_ORDERING_CASES = [ + # (preset, include, exclude) + ("session", [], []), + ("session", ["system_prompt"], []), + ("session", [], ["messages"]), + ("session", ["system_prompt"], ["messages", "state"]), + (None, ["messages", "state"], []), + (None, list(ALL_SNAPSHOT_FIELDS), []), + (None, list(ALL_SNAPSHOT_FIELDS), ["system_prompt"]), + ("session", ["system_prompt"], list(SNAPSHOT_PRESETS["session"])), # exclude all preset → only system_prompt +] + + +@pytest.mark.parametrize("preset,include,exclude", _ORDERING_CASES) +def test_resolve_snapshot_fields_ordering(preset, include, exclude): + expected = (set(SNAPSHOT_PRESETS[preset] if preset else []) | set(include)) - set(exclude) + options: TakeSnapshotOptions = {} + if preset is not None: + options["preset"] = preset # type: ignore[assignment] + if include: + options["include"] = include # type: ignore[assignment] + if exclude: + options["exclude"] = exclude # type: ignore[assignment] + + if not expected: + with pytest.raises(SnapshotException): + resolve_snapshot_fields(options) + else: + assert resolve_snapshot_fields(options) == expected + + +_STRUCTURAL_CASES = [ + ([], {}, None), + ([{"role": "user", "content": [{"text": "hi"}]}], {"k": "v"}, "system prompt"), + ([{"role": "user", "content": [{"text": "a"}]}, {"role": "user", "content": [{"text": "b"}]}], {}, None), + ([], {"num": 42, "flag": True}, "another prompt"), +] + + +@pytest.mark.parametrize("messages,state_dict,system_prompt", _STRUCTURAL_CASES) +def test_snapshot_structural_invariants(messages, state_dict, system_prompt): + agent = _make_agent(messages=messages, state=state_dict, system_prompt=system_prompt) + snapshot = agent.take_snapshot(preset="session") + + assert snapshot.schema_version == "1.0" + assert ISO_8601_UTC_RE.match(snapshot.created_at), f"created_at={snapshot.created_at!r} not ISO 8601 UTC" + assert isinstance(snapshot.data, dict) + assert isinstance(snapshot.app_data, dict) + for field in ("messages", "state", "conversation_manager_state", "interrupt_state"): + assert field in snapshot.data + assert "system_prompt" not in snapshot.data + + +_APP_DATA_CASES = [ + {"key": "value"}, + {"num": 42, "flag": True, "nothing": None}, + {"nested_str": "hello", "count": 0}, +] + + +@pytest.mark.parametrize("app_data", _APP_DATA_CASES) +def test_app_data_stored_verbatim(app_data): + agent = _make_agent() + snapshot = agent.take_snapshot(preset="session", app_data=app_data) + assert snapshot.app_data == app_data + + +_ROUND_TRIP_AGENT_CASES = [ + ([], {}), + ([{"role": "user", "content": [{"text": "hi"}]}], {"k": "v"}), + ( + [{"role": "user", "content": [{"text": "a"}]}, {"role": "user", "content": [{"text": "b"}]}], + {"num": 1, "flag": None}, + ), +] + + +@pytest.mark.parametrize("messages,state_dict", _ROUND_TRIP_AGENT_CASES) +def test_agent_state_round_trip(messages, state_dict): + agent = _make_agent(messages=messages, state=state_dict, system_prompt="original prompt") + snapshot = agent.take_snapshot(preset="session") + + fresh_agent = _make_agent(system_prompt="original prompt") + fresh_agent.load_snapshot(snapshot) + + assert fresh_agent.messages == messages + assert fresh_agent.state.get() == state_dict + assert fresh_agent.system_prompt == "original prompt" + assert fresh_agent.conversation_manager.get_state() == agent.conversation_manager.get_state() + assert fresh_agent._interrupt_state.to_dict() == agent._interrupt_state.to_dict() + + +@pytest.mark.parametrize("omitted_field", list(ALL_SNAPSHOT_FIELDS)) +def test_missing_fields_leave_agent_unchanged(omitted_field): + agent = _make_agent( + messages=[{"role": "user", "content": [{"text": "original"}]}], + state={"key": "original"}, + system_prompt="original prompt", + ) + + include_fields = [f for f in ALL_SNAPSHOT_FIELDS if f != omitted_field] + snapshot = agent.take_snapshot(include=include_fields) + # system_prompt field is stored under the key "system_prompt" in snapshot.data + data_key = "system_prompt" if omitted_field == "system_prompt" else omitted_field + assert data_key not in snapshot.data + + fresh_agent = _make_agent( + messages=list(agent.messages), + state=agent.state.get(), + system_prompt="original prompt", + ) + + if omitted_field == "messages": + before = list(fresh_agent.messages) + elif omitted_field == "state": + before = fresh_agent.state.get() + elif omitted_field == "system_prompt": + before = fresh_agent.system_prompt + elif omitted_field == "conversation_manager_state": + before = fresh_agent.conversation_manager.get_state() + elif omitted_field == "interrupt_state": + before = fresh_agent._interrupt_state.to_dict() + else: + pytest.fail(f"Unhandled field in test: {omitted_field!r}. Update this test when adding new snapshot fields.") + + fresh_agent.load_snapshot(snapshot) + + if omitted_field == "messages": + assert fresh_agent.messages == before + elif omitted_field == "state": + assert fresh_agent.state.get() == before + elif omitted_field == "system_prompt": + assert fresh_agent.system_prompt == before + elif omitted_field == "conversation_manager_state": + assert fresh_agent.conversation_manager.get_state() == before + elif omitted_field == "interrupt_state": + assert fresh_agent._interrupt_state.to_dict() == before + else: + pytest.fail(f"Unhandled field in test: {omitted_field!r}. Update this test when adding new snapshot fields.") + + +def test_snapshot_no_system_prompt_clears_target_agent_prompt(): + """Snapshot from agent with no system_prompt (field included) clears prompt on restore.""" + source_agent = _make_agent() # no system_prompt + snapshot = source_agent.take_snapshot(include=["system_prompt"]) + + assert "system_prompt" in snapshot.data + assert snapshot.data["system_prompt"] is None + + target_agent = _make_agent(system_prompt="existing prompt") + target_agent.load_snapshot(snapshot) + + assert target_agent.system_prompt is None + + +def test_snapshot_without_system_prompt_field_preserves_target_agent_prompt(): + """Snapshot taken without system_prompt field does not override target agent's prompt.""" + source_agent = _make_agent(system_prompt="source prompt") + snapshot = source_agent.take_snapshot(include=["messages"]) # system_prompt field excluded + + assert "system_prompt" not in snapshot.data + + target_agent = _make_agent(system_prompt="target prompt") + target_agent.load_snapshot(snapshot) + + assert target_agent.system_prompt == "target prompt" + + +def test_load_snapshot_messages_are_independent_copy(): + """Messages restored from a snapshot are a copy — mutating snapshot.data after load doesn't affect the agent.""" + agent = _make_agent(messages=[{"role": "user", "content": [{"text": "hello"}]}]) + snapshot = agent.take_snapshot(preset="session") + + fresh_agent = _make_agent() + fresh_agent.load_snapshot(snapshot) + + snapshot.data["messages"].append({"role": "user", "content": [{"text": "injected"}]}) + assert len(fresh_agent.messages) == 1 + + +def test_take_snapshot_messages_are_independent_copy(): + """Mutating agent messages after take_snapshot doesn't corrupt the snapshot.""" + msg = {"role": "user", "content": [{"text": "original"}]} + agent = _make_agent(messages=[msg]) + snapshot = agent.take_snapshot(preset="session") + + agent.messages[0]["content"][0]["text"] = "mutated" + assert snapshot.data["messages"][0]["content"][0]["text"] == "original" + + +def test_take_snapshot_app_data_is_independent_copy(): + """Mutating app_data after take_snapshot doesn't corrupt the snapshot.""" + app_data = {"key": "original"} + agent = _make_agent() + snapshot = agent.take_snapshot(preset="session", app_data=app_data) + + app_data["key"] = "mutated" + assert snapshot.app_data["key"] == "original"