From 384f34d51efb2b503bfc19c9115ece29dc7601b9 Mon Sep 17 00:00:00 2001 From: Copilot Date: Fri, 10 Apr 2026 06:18:40 +0000 Subject: [PATCH 1/3] Python: Add allowed_checkpoint_types support to CosmosCheckpointStorage (#5200) Add allowed_checkpoint_types parameter to CosmosCheckpointStorage for parity with FileCheckpointStorage. This ensures both providers use the same restricted pickle deserialization by default. Changes: - Accept allowed_checkpoint_types kwarg in __init__, stored as frozenset - Convert _document_to_checkpoint from @staticmethod to instance method - Forward allowed_types to decode_checkpoint_value on all load paths - Update class docstring to describe the new parameter - Add tests covering built-in safe types, app type opt-in/blocking, and all load paths (load, list_checkpoints, get_latest) - Add changelog entry noting the breaking behavior change BREAKING CHANGE: CosmosCheckpointStorage now uses restricted pickle deserialization by default. Checkpoints containing application-defined types will require passing those types via allowed_checkpoint_types. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- python/CHANGELOG.md | 3 + .../_checkpoint_storage.py | 30 +++- .../tests/test_cosmos_checkpoint_storage.py | 140 ++++++++++++++++++ 3 files changed, 167 insertions(+), 6 deletions(-) diff --git a/python/CHANGELOG.md b/python/CHANGELOG.md index 99947710c9..0ae0df1454 100644 --- a/python/CHANGELOG.md +++ b/python/CHANGELOG.md @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Changed +- **agent-framework-azure-cosmos**: [BREAKING] `CosmosCheckpointStorage` now uses restricted pickle deserialization by default, matching `FileCheckpointStorage` behavior. If your checkpoints contain application-defined types, pass them via `allowed_checkpoint_types=["my_app.models:MyState"]`. ([#5200](https://github.com/microsoft/agent-framework/issues/5200)) + ## [1.0.1] - 2026-04-09 ### Added diff --git a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py index 1b6257f203..76406e6ebd 100644 --- a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py +++ b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py @@ -43,9 +43,22 @@ class CosmosCheckpointStorage: ``FileCheckpointStorage``, allowing full Python object fidelity for complex workflow state while keeping the document structure human-readable. - SECURITY WARNING: Checkpoints use pickle for data serialization. Only load - checkpoints from trusted sources. Loading a malicious checkpoint can execute - arbitrary code. + By default, checkpoint deserialization is restricted to a built-in set of safe + Python types (primitives, datetime, uuid, ...) and all ``agent_framework`` + internal types. To allow additional application-specific types, pass them via + the ``allowed_checkpoint_types`` parameter using ``"module:qualname"`` format. + + Example:: + + storage = CosmosCheckpointStorage( + endpoint="https://my-account.documents.azure.com:443/", + credential=DefaultAzureCredential(), + database_name="agent-db", + container_name="checkpoints", + allowed_checkpoint_types=[ + "my_app.models:MyState", + ], + ) The database and container are created automatically on first use if they do not already exist. The container uses partition key @@ -97,6 +110,7 @@ def __init__( container_client: ContainerProxy | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, + allowed_checkpoint_types: list[str] | None = None, ) -> None: """Initialize the Azure Cosmos DB checkpoint storage. @@ -129,10 +143,15 @@ def __init__( container_client: Pre-created Cosmos container client. env_file_path: Path to environment file for loading settings. env_file_encoding: Encoding of the environment file. + allowed_checkpoint_types: Additional types (beyond the built-in safe set + and framework types) that are permitted during checkpoint + deserialization. Each entry should be a ``"module:qualname"`` + string (e.g., ``"my_app.models:MyState"``). """ self._cosmos_client: CosmosClient | None = cosmos_client self._container_proxy: ContainerProxy | None = container_client self._owns_client = False + self._allowed_types: frozenset[str] = frozenset(allowed_checkpoint_types or []) if self._container_proxy is not None: self.database_name: str = database_name or "" @@ -401,8 +420,7 @@ async def _ensure_container_proxy(self) -> None: partition_key=PartitionKey(path="/workflow_name"), ) - @staticmethod - def _document_to_checkpoint(document: dict[str, Any]) -> WorkflowCheckpoint: + def _document_to_checkpoint(self, document: dict[str, Any]) -> WorkflowCheckpoint: """Convert a Cosmos DB document back to a WorkflowCheckpoint. Strips Cosmos DB system properties (``_rid``, ``_self``, ``_etag``, @@ -413,7 +431,7 @@ def _document_to_checkpoint(document: dict[str, Any]) -> WorkflowCheckpoint: cosmos_keys = {"id", "_rid", "_self", "_etag", "_attachments", "_ts"} cleaned = {k: v for k, v in document.items() if k not in cosmos_keys} - decoded = decode_checkpoint_value(cleaned) + decoded = decode_checkpoint_value(cleaned, allowed_types=self._allowed_types) return WorkflowCheckpoint.from_dict(decoded) @staticmethod diff --git a/python/packages/azure-cosmos/tests/test_cosmos_checkpoint_storage.py b/python/packages/azure-cosmos/tests/test_cosmos_checkpoint_storage.py index 52155d0e21..e92bf574df 100644 --- a/python/packages/azure-cosmos/tests/test_cosmos_checkpoint_storage.py +++ b/python/packages/azure-cosmos/tests/test_cosmos_checkpoint_storage.py @@ -5,6 +5,7 @@ import os import uuid from collections.abc import AsyncIterator +from dataclasses import dataclass from contextlib import suppress from typing import Any from unittest.mock import AsyncMock, MagicMock, patch @@ -595,3 +596,142 @@ async def test_cosmos_checkpoint_storage_roundtrip_with_emulator() -> None: finally: with suppress(Exception): await cosmos_client.delete_database(database_name) + + +# --- Tests for allowed_checkpoint_types --- + + +@dataclass +class _AppState: + """Application-defined state type used to test allowed_checkpoint_types.""" + + label: str + count: int + + +_APP_STATE_TYPE_KEY = f"{_AppState.__module__}:{_AppState.__qualname__}" + + +def _make_checkpoint_with_state(state: dict[str, Any]) -> WorkflowCheckpoint: + """Create a checkpoint with custom state for serialization tests.""" + return WorkflowCheckpoint( + workflow_name="test-workflow", + graph_signature_hash="abc123", + timestamp="2025-01-01T00:00:00+00:00", + state=state, + iteration_count=1, + ) + + +async def test_init_accepts_allowed_checkpoint_types(mock_container: MagicMock) -> None: + """CosmosCheckpointStorage.__init__ accepts allowed_checkpoint_types.""" + storage = CosmosCheckpointStorage( + container_client=mock_container, + allowed_checkpoint_types=["some.module:SomeType"], + ) + assert storage is not None + + +async def test_load_allows_builtin_safe_types(mock_container: MagicMock) -> None: + """Built-in safe types load without opt-in via allowed_checkpoint_types.""" + from datetime import datetime, timezone + + checkpoint = _make_checkpoint_with_state({ + "ts": datetime(2025, 1, 1, tzinfo=timezone.utc), + "tags": {1, 2, 3}, + }) + doc = _checkpoint_to_cosmos_document(checkpoint) + mock_container.query_items.return_value = _to_async_iter([doc]) + + storage = CosmosCheckpointStorage(container_client=mock_container) + loaded = await storage.load(checkpoint.checkpoint_id) + + assert loaded.state["ts"] == datetime(2025, 1, 1, tzinfo=timezone.utc) + assert loaded.state["tags"] == {1, 2, 3} + + +async def test_load_blocks_unlisted_app_type(mock_container: MagicMock) -> None: + """Application types are blocked when not listed in allowed_checkpoint_types.""" + checkpoint = _make_checkpoint_with_state({"data": _AppState(label="x", count=1)}) + doc = _checkpoint_to_cosmos_document(checkpoint) + mock_container.query_items.return_value = _to_async_iter([doc]) + + storage = CosmosCheckpointStorage(container_client=mock_container) + + with pytest.raises(WorkflowCheckpointException, match="deserialization blocked"): + await storage.load(checkpoint.checkpoint_id) + + +async def test_load_allows_listed_app_type(mock_container: MagicMock) -> None: + """Application types are allowed when listed in allowed_checkpoint_types.""" + checkpoint = _make_checkpoint_with_state({"data": _AppState(label="ok", count=7)}) + doc = _checkpoint_to_cosmos_document(checkpoint) + mock_container.query_items.return_value = _to_async_iter([doc]) + + storage = CosmosCheckpointStorage( + container_client=mock_container, + allowed_checkpoint_types=[_APP_STATE_TYPE_KEY], + ) + loaded = await storage.load(checkpoint.checkpoint_id) + + assert isinstance(loaded.state["data"], _AppState) + assert loaded.state["data"].label == "ok" + assert loaded.state["data"].count == 7 + + +async def test_list_checkpoints_blocks_unlisted_app_type(mock_container: MagicMock) -> None: + """list_checkpoints skips documents with unlisted application types.""" + checkpoint = _make_checkpoint_with_state({"data": _AppState(label="x", count=1)}) + doc = _checkpoint_to_cosmos_document(checkpoint) + mock_container.query_items.return_value = _to_async_iter([doc]) + + storage = CosmosCheckpointStorage(container_client=mock_container) + results = await storage.list_checkpoints(workflow_name="test-workflow") + + # The document is skipped (logged as warning) because the type is blocked + assert len(results) == 0 + + +async def test_list_checkpoints_allows_listed_app_type(mock_container: MagicMock) -> None: + """list_checkpoints decodes documents with listed application types.""" + checkpoint = _make_checkpoint_with_state({"data": _AppState(label="ok", count=3)}) + doc = _checkpoint_to_cosmos_document(checkpoint) + mock_container.query_items.return_value = _to_async_iter([doc]) + + storage = CosmosCheckpointStorage( + container_client=mock_container, + allowed_checkpoint_types=[_APP_STATE_TYPE_KEY], + ) + results = await storage.list_checkpoints(workflow_name="test-workflow") + + assert len(results) == 1 + assert isinstance(results[0].state["data"], _AppState) + + +async def test_get_latest_blocks_unlisted_app_type(mock_container: MagicMock) -> None: + """get_latest raises when the checkpoint contains an unlisted application type.""" + checkpoint = _make_checkpoint_with_state({"data": _AppState(label="x", count=1)}) + doc = _checkpoint_to_cosmos_document(checkpoint) + mock_container.query_items.return_value = _to_async_iter([doc]) + + storage = CosmosCheckpointStorage(container_client=mock_container) + + with pytest.raises(WorkflowCheckpointException, match="deserialization blocked"): + await storage.get_latest(workflow_name="test-workflow") + + +async def test_get_latest_allows_listed_app_type(mock_container: MagicMock) -> None: + """get_latest decodes checkpoints with listed application types.""" + checkpoint = _make_checkpoint_with_state({"data": _AppState(label="latest", count=9)}) + doc = _checkpoint_to_cosmos_document(checkpoint) + mock_container.query_items.return_value = _to_async_iter([doc]) + + storage = CosmosCheckpointStorage( + container_client=mock_container, + allowed_checkpoint_types=[_APP_STATE_TYPE_KEY], + ) + result = await storage.get_latest(workflow_name="test-workflow") + + assert result is not None + assert isinstance(result.state["data"], _AppState) + assert result.state["data"].label == "latest" From 9bf8f2888b0ce270f0feda49289f2e545160e7b5 Mon Sep 17 00:00:00 2001 From: Copilot Date: Fri, 10 Apr 2026 06:26:57 +0000 Subject: [PATCH 2/3] Python: Add `allowed_checkpoint_types` support to `CosmosCheckpointStorage` for parity with `FileCheckpointStorage` Fixes #5200 --- .../azure-cosmos/tests/test_cosmos_checkpoint_storage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/packages/azure-cosmos/tests/test_cosmos_checkpoint_storage.py b/python/packages/azure-cosmos/tests/test_cosmos_checkpoint_storage.py index e92bf574df..016220e693 100644 --- a/python/packages/azure-cosmos/tests/test_cosmos_checkpoint_storage.py +++ b/python/packages/azure-cosmos/tests/test_cosmos_checkpoint_storage.py @@ -5,8 +5,8 @@ import os import uuid from collections.abc import AsyncIterator -from dataclasses import dataclass from contextlib import suppress +from dataclasses import dataclass from typing import Any from unittest.mock import AsyncMock, MagicMock, patch From 88c9e99e8429299ab2596da9fe27f08d69bb5966 Mon Sep 17 00:00:00 2001 From: Copilot Date: Fri, 10 Apr 2026 06:32:22 +0000 Subject: [PATCH 3/3] Address PR review: add pickle security warning and fix docstring examples - Reintroduce explicit security warning about pickle deserialization risks - Convert Example:: block to .. code-block:: python with imports for consistency with other docstring examples - Note: PR title should be updated to include [BREAKING] prefix per changelog convention (comment #3, requires GitHub UI change) Fixes #5200 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../_checkpoint_storage.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py index 76406e6ebd..496d95d7c3 100644 --- a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py +++ b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py @@ -43,12 +43,24 @@ class CosmosCheckpointStorage: ``FileCheckpointStorage``, allowing full Python object fidelity for complex workflow state while keeping the document structure human-readable. + Security warning: checkpoints use pickle for non-JSON-native values. Loading + checkpoints from untrusted sources is unsafe and can execute arbitrary code + during deserialization. The built-in deserialization restrictions reduce risk, + but they do not make untrusted checkpoints safe to load. Extending + ``allowed_checkpoint_types`` may further increase risk and should only be done + for trusted application types. + By default, checkpoint deserialization is restricted to a built-in set of safe Python types (primitives, datetime, uuid, ...) and all ``agent_framework`` internal types. To allow additional application-specific types, pass them via the ``allowed_checkpoint_types`` parameter using ``"module:qualname"`` format. - Example:: + Example: + + .. code-block:: python + + from azure.identity.aio import DefaultAzureCredential + from agent_framework_azure_cosmos import CosmosCheckpointStorage storage = CosmosCheckpointStorage( endpoint="https://my-account.documents.azure.com:443/",