From fc64013800bdbd418979983bac862796143c19ec Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 23 May 2026 17:04:21 +0000 Subject: [PATCH 01/29] feat(adk): align memory service with clean-break contract --- docs/extensions/adk/api.rst | 6 - docs/reference/extensions/adk.rst | 4 - sqlspec/extensions/adk/__init__.py | 3 - sqlspec/extensions/adk/memory/__init__.py | 6 +- sqlspec/extensions/adk/memory/converters.py | 16 ++- sqlspec/extensions/adk/memory/service.py | 106 ++-------------- .../test_adk/test_memory_converters.py | 31 ++++- .../test_adk/test_memory_service.py | 113 ++++++++++++++++++ 8 files changed, 168 insertions(+), 117 deletions(-) create mode 100644 tests/unit/extensions/test_adk/test_memory_service.py diff --git a/docs/extensions/adk/api.rst b/docs/extensions/adk/api.rst index e95a28a31..57429736c 100644 --- a/docs/extensions/adk/api.rst +++ b/docs/extensions/adk/api.rst @@ -19,12 +19,6 @@ Services :show-inheritance: :no-index: -.. autoclass:: sqlspec.extensions.adk.memory.SQLSpecSyncMemoryService - :members: - :undoc-members: - :show-inheritance: - :no-index: - .. autoclass:: SQLSpecArtifactService :members: :undoc-members: diff --git a/docs/reference/extensions/adk.rst b/docs/reference/extensions/adk.rst index f404da9b2..5a1dc9c9b 100644 --- a/docs/reference/extensions/adk.rst +++ b/docs/reference/extensions/adk.rst @@ -19,10 +19,6 @@ Memory Services :members: :show-inheritance: -.. autoclass:: sqlspec.extensions.adk.SQLSpecSyncMemoryService - :members: - :show-inheritance: - Artifact Service ================ diff --git a/sqlspec/extensions/adk/__init__.py b/sqlspec/extensions/adk/__init__.py index de3e3ad14..6cf33c24f 100644 --- a/sqlspec/extensions/adk/__init__.py +++ b/sqlspec/extensions/adk/__init__.py @@ -7,7 +7,6 @@ - ADKConfig: TypedDict for extension config (type-safe configuration) - SQLSpecSessionService: Main service class implementing BaseSessionService - SQLSpecMemoryService: Main async service class implementing BaseMemoryService - - SQLSpecSyncMemoryService: Sync memory service for sync adapters - SQLSpecArtifactService: Artifact service implementing BaseArtifactService - BaseAsyncADKStore: Base class for async database store implementations - BaseSyncADKStore: Base class for sync database store implementations @@ -60,7 +59,6 @@ BaseSyncADKMemoryStore, MemoryRecord, SQLSpecMemoryService, - SQLSpecSyncMemoryService, ) from sqlspec.extensions.adk.service import SQLSpecSessionService from sqlspec.extensions.adk.store import BaseAsyncADKStore, BaseSyncADKStore @@ -79,6 +77,5 @@ "SQLSpecArtifactService", "SQLSpecMemoryService", "SQLSpecSessionService", - "SQLSpecSyncMemoryService", "SessionRecord", ) diff --git a/sqlspec/extensions/adk/memory/__init__.py b/sqlspec/extensions/adk/memory/__init__.py index 4522d1af7..5c661a8f5 100644 --- a/sqlspec/extensions/adk/memory/__init__.py +++ b/sqlspec/extensions/adk/memory/__init__.py @@ -6,9 +6,8 @@ Public API exports: - SQLSpecMemoryService: Main async service class implementing BaseMemoryService - - SQLSpecSyncMemoryService: Sync service for sync adapters - BaseAsyncADKMemoryStore: Base class for async database store implementations - - BaseSyncADKMemoryStore: Base class for sync database store implementations + - BaseSyncADKMemoryStore: Internal base for sync stores wrapped behind async APIs - MemoryRecord: TypedDict for memory database records - extract_content_text: Helper to extract searchable text from Content - session_to_memory_records: Convert Session to memory records @@ -54,7 +53,7 @@ record_to_memory_entry, session_to_memory_records, ) -from sqlspec.extensions.adk.memory.service import SQLSpecMemoryService, SQLSpecSyncMemoryService +from sqlspec.extensions.adk.memory.service import SQLSpecMemoryService from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore, BaseSyncADKMemoryStore __all__ = ( @@ -62,7 +61,6 @@ "BaseSyncADKMemoryStore", "MemoryRecord", "SQLSpecMemoryService", - "SQLSpecSyncMemoryService", "extract_content_text", "record_to_memory_entry", "session_to_memory_records", diff --git a/sqlspec/extensions/adk/memory/converters.py b/sqlspec/extensions/adk/memory/converters.py index fafea6d12..37b4c2696 100644 --- a/sqlspec/extensions/adk/memory/converters.py +++ b/sqlspec/extensions/adk/memory/converters.py @@ -10,6 +10,7 @@ from sqlspec.extensions.adk.memory._types import MemoryRecord from sqlspec.utils.logging import get_logger +from sqlspec.utils.serializers import to_json if TYPE_CHECKING: from google.adk.events.event import Event @@ -29,6 +30,15 @@ ) +def _payload_to_search_text(payload: object | None) -> str: + """Serialize structured ADK part payloads into deterministic search text.""" + if payload is None: + return "" + if isinstance(payload, str): + return payload + return to_json(payload) + + def extract_content_text(content: "types.Content") -> str: """Extract plain text from ADK Content for search indexing. @@ -51,9 +61,11 @@ def extract_content_text(content: "types.Content") -> str: if part.text: parts_text.append(part.text) elif part.function_call is not None: - parts_text.append(f"function:{part.function_call.name}") + payload_text = _payload_to_search_text(part.function_call.args) + parts_text.append(f"function:{part.function_call.name} {payload_text}".strip()) elif part.function_response is not None: - parts_text.append(f"response:{part.function_response.name}") + payload_text = _payload_to_search_text(part.function_response.response) + parts_text.append(f"response:{part.function_response.name} {payload_text}".strip()) return " ".join(parts_text) diff --git a/sqlspec/extensions/adk/memory/service.py b/sqlspec/extensions/adk/memory/service.py index 4e977992f..7f106df01 100644 --- a/sqlspec/extensions/adk/memory/service.py +++ b/sqlspec/extensions/adk/memory/service.py @@ -18,11 +18,11 @@ from google.adk.memory.memory_entry import MemoryEntry from google.adk.sessions import Session - from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore, BaseSyncADKMemoryStore + from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore logger = get_logger("sqlspec.extensions.adk.memory.service") -__all__ = ("SQLSpecMemoryService", "SQLSpecSyncMemoryService") +__all__ = ("SQLSpecMemoryService",) class SQLSpecMemoryService(BaseMemoryService): @@ -31,8 +31,10 @@ class SQLSpecMemoryService(BaseMemoryService): Provides memory entry storage using SQLSpec database adapters. Delegates all database operations to a store implementation. - ADK BaseMemoryService defines two core methods: + ADK BaseMemoryService defines the memory write/read contract: - add_session_to_memory(session) - Ingests session into memory (returns void) + - add_events_to_memory(...) - Ingests explicit event deltas + - add_memory(...) - Persists explicit MemoryEntry objects - search_memory(app_name, user_id, query) - Searches stored memories Args: @@ -143,7 +145,10 @@ async def add_events_to_memory( ) if record is not None: if metadata_dict: - record["metadata_json"] = metadata_dict + merged_metadata = dict(metadata_dict) + if record["metadata_json"]: + merged_metadata.update(record["metadata_json"]) + record["metadata_json"] = merged_metadata records.append(record) if not records: @@ -221,96 +226,3 @@ async def search_memory(self, *, app_name: str, user_id: str, query: str) -> "Se logger.debug("Found %d memories for query '%s' (app=%s, user=%s)", len(memories), query[:50], app_name, user_id) return SearchMemoryResponse(memories=memories) - - -class SQLSpecSyncMemoryService: - """Synchronous SQLSpec-backed memory service. - - Provides memory entry storage using SQLSpec sync database adapters. - This is a sync-compatible version for use with sync drivers like SQLite. - - Note: This does NOT inherit from BaseMemoryService since ADK's base class - requires async methods. Use this for sync-only deployments. - - Args: - store: Sync database store implementation. - - Example: - from sqlspec.adapters.sqlite import SqliteConfig - from sqlspec.adapters.sqlite.adk import SqliteADKMemoryStore - from sqlspec.extensions.adk.memory.service import SQLSpecSyncMemoryService - - config = SqliteConfig( - connection_config={"database": "app.db"}, - extension_config={ - "adk": { - "memory_table": "adk_memory_entries", - } - } - ) - store = SqliteADKMemoryStore(config) - store.ensure_tables() - - service = SQLSpecSyncMemoryService(store) - service.add_session_to_memory(completed_session) - - memories = service.search_memory( - app_name="my_app", - user_id="user123", - query="Python discussion" - ) - """ - - def __init__(self, store: "BaseSyncADKMemoryStore") -> None: - """Initialize the sync memory service. - - Args: - store: Sync database store implementation. - """ - self._store = store - - @property - def store(self) -> "BaseSyncADKMemoryStore": - """Return the database store.""" - return self._store - - def add_session_to_memory(self, session: "Session") -> None: - """Add a completed session to the memory store. - - Extracts all events with content from the session and stores them - as searchable memory entries. Uses UPSERT to skip duplicates. - - Args: - session: Completed ADK Session with events. - """ - records = session_to_memory_records(session) - - if not records: - logger.debug( - "No content to store for session %s (app=%s, user=%s)", session.id, session.app_name, session.user_id - ) - return - - inserted_count = self._store.insert_memory_entries(records) - logger.debug( - "Stored %d memory entries for session %s (total events: %d)", inserted_count, session.id, len(records) - ) - - def search_memory(self, *, app_name: str, user_id: str, query: str) -> list["MemoryEntry"]: - """Search memory entries by text query. - - Args: - app_name: Name of the application. - user_id: ID of the user. - query: Text query to search for. - - Returns: - List of MemoryEntry objects. - """ - records = self._store.search_entries(query=query, app_name=app_name, user_id=user_id) - - memories = records_to_memory_entries(records) - - logger.debug("Found %d memories for query '%s' (app=%s, user=%s)", len(memories), query[:50], app_name, user_id) - - return memories diff --git a/tests/unit/extensions/test_adk/test_memory_converters.py b/tests/unit/extensions/test_adk/test_memory_converters.py index e32cda6d5..aae2f43ad 100644 --- a/tests/unit/extensions/test_adk/test_memory_converters.py +++ b/tests/unit/extensions/test_adk/test_memory_converters.py @@ -10,12 +10,14 @@ from google.adk.events.event import Event from google.adk.events.event_actions import EventActions +from google.adk.memory.memory_entry import MemoryEntry from google.adk.sessions.session import Session from google.genai import types from sqlspec.extensions.adk.memory.converters import ( event_to_memory_record, extract_content_text, + memory_entry_to_record, record_to_memory_entry, session_to_memory_records, ) @@ -39,14 +41,16 @@ def test_extract_content_text_combines_parts() -> None: content = types.Content( parts=[ types.Part(text="hello"), - types.Part(function_call=types.FunctionCall(name="lookup")), + types.Part(function_call=types.FunctionCall(name="lookup", args={"sku": "espresso"})), types.Part(function_response=types.FunctionResponse(name="lookup", response={"output": "ok"})), ] ) text = extract_content_text(content) assert "hello" in text assert "function:lookup" in text + assert "espresso" in text assert "response:lookup" in text + assert "ok" in text def test_event_to_memory_record_skips_empty_content() -> None: @@ -67,3 +71,28 @@ def test_session_to_memory_records_roundtrip() -> None: assert entry.content is not None assert entry.content.parts is not None assert entry.content.parts[0].text == "Hello memory" + + +def test_memory_entry_to_record_preserves_identity_and_metadata() -> None: + entry = MemoryEntry( + id="memory-1", + author="agent", + timestamp="2026-05-23T12:00:00+00:00", + content=types.Content(parts=[types.Part(text="Remember espresso roast")]), + custom_metadata={"source": "entry", "priority": 2}, + ) + + record = memory_entry_to_record( + entry=entry, app_name="app", user_id="user", extra_metadata={"source": "call", "ttl": 3600} + ) + + assert record is not None + assert record["id"] == "memory-1" + assert record["author"] == "agent" + assert record["content_text"] == "Remember espresso roast" + assert record["metadata_json"] == {"source": "entry", "ttl": 3600, "priority": 2} + + round_tripped = record_to_memory_entry(record) + assert round_tripped.id == "memory-1" + assert round_tripped.author == "agent" + assert round_tripped.custom_metadata == {"source": "entry", "ttl": 3600, "priority": 2} diff --git a/tests/unit/extensions/test_adk/test_memory_service.py b/tests/unit/extensions/test_adk/test_memory_service.py new file mode 100644 index 000000000..4139b0769 --- /dev/null +++ b/tests/unit/extensions/test_adk/test_memory_service.py @@ -0,0 +1,113 @@ +"""Unit tests for ADK memory service clean-break behavior.""" + +import importlib.util +from datetime import datetime, timezone +from typing import TYPE_CHECKING + +import pytest + +if importlib.util.find_spec("google.genai") is None or importlib.util.find_spec("google.adk") is None: + pytest.skip("google-adk not installed", allow_module_level=True) + +from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions +from google.adk.memory.memory_entry import MemoryEntry +from google.genai import types + +from sqlspec.extensions.adk.memory import SQLSpecMemoryService + +if TYPE_CHECKING: + from sqlspec.extensions.adk.memory import MemoryRecord + + +class _MemoryStore: + def __init__(self) -> None: + self.entries: list[MemoryRecord] = [] + + async def insert_memory_entries(self, entries: list["MemoryRecord"], owner_id: object | None = None) -> int: + self.entries.extend(entries) + return len(entries) + + async def search_entries( + self, query: str, app_name: str, user_id: str, limit: int | None = None + ) -> list["MemoryRecord"]: + return [ + entry + for entry in self.entries + if entry["app_name"] == app_name and entry["user_id"] == user_id and query in entry["content_text"] + ] + + +def _event(event_id: str, text: str, custom_metadata: dict[str, object] | None = None) -> Event: + return Event( + id=event_id, + invocation_id="inv-1", + author="user", + content=types.Content(parts=[types.Part(text=text)]), + actions=EventActions(), + timestamp=datetime.now(timezone.utc).timestamp(), + partial=False, + turn_complete=True, + custom_metadata=custom_metadata, + ) + + +async def test_add_events_to_memory_persists_user_scoped_delta_metadata() -> None: + store = _MemoryStore() + service = SQLSpecMemoryService(store) # type: ignore[arg-type] + + await service.add_events_to_memory( + app_name="app", user_id="user", events=[_event("evt-1", "delta memory")], custom_metadata={"ttl": 3600} + ) + + assert len(store.entries) == 1 + record = store.entries[0] + assert record["session_id"] == "" + assert record["event_id"] == "evt-1" + assert record["metadata_json"] == {"ttl": 3600} + + +async def test_add_events_to_memory_merges_event_and_call_metadata() -> None: + store = _MemoryStore() + service = SQLSpecMemoryService(store) # type: ignore[arg-type] + + await service.add_events_to_memory( + app_name="app", + user_id="user", + events=[_event("evt-1", "delta memory", custom_metadata={"source": "event", "priority": 2})], + custom_metadata={"source": "call", "ttl": 3600}, + ) + + assert len(store.entries) == 1 + assert store.entries[0]["metadata_json"] == {"source": "event", "ttl": 3600, "priority": 2} + + +async def test_add_memory_preserves_entry_metadata_and_supports_search() -> None: + store = _MemoryStore() + service = SQLSpecMemoryService(store) # type: ignore[arg-type] + memory = MemoryEntry( + id="memory-1", + author="agent", + timestamp="2026-05-23T12:00:00+00:00", + content=types.Content(parts=[types.Part(text="direct memory")]), + custom_metadata={"source": "entry"}, + ) + + await service.add_memory(app_name="app", user_id="user", memories=[memory], custom_metadata={"ttl": 3600}) + response = await service.search_memory(app_name="app", user_id="user", query="direct") + + assert len(response.memories) == 1 + result = response.memories[0] + assert result.id == "memory-1" + assert result.author == "agent" + assert result.custom_metadata == {"ttl": 3600, "source": "entry"} + + +def test_sync_memory_service_is_not_public_clean_break_surface() -> None: + import sqlspec.extensions.adk as adk + import sqlspec.extensions.adk.memory as memory + + assert "SQLSpecSyncMemoryService" not in adk.__all__ + assert "SQLSpecSyncMemoryService" not in memory.__all__ + assert not hasattr(adk, "SQLSpecSyncMemoryService") + assert not hasattr(memory, "SQLSpecSyncMemoryService") From 5b51e1b29b67b9da7c5221490f865e0934407824 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 23 May 2026 17:23:24 +0000 Subject: [PATCH 02/29] feat(adk): centralize identifier validation --- sqlspec/extensions/adk/artifact/store.py | 36 ++----------- sqlspec/extensions/adk/memory/store.py | 32 ++---------- sqlspec/extensions/adk/store.py | 42 ++------------- sqlspec/extensions/events/_store.py | 22 ++++---- sqlspec/extensions/litestar/channels.py | 10 ++-- sqlspec/extensions/litestar/store.py | 37 ++----------- sqlspec/utils/identifiers.py | 56 ++++++++++++++++++++ tests/unit/utils/test_identifiers.py | 66 ++++++++++++++++++++++++ 8 files changed, 152 insertions(+), 149 deletions(-) create mode 100644 sqlspec/utils/identifiers.py create mode 100644 tests/unit/utils/test_identifiers.py diff --git a/sqlspec/extensions/adk/artifact/store.py b/sqlspec/extensions/adk/artifact/store.py index 0ad05724f..a9cd7a353 100644 --- a/sqlspec/extensions/adk/artifact/store.py +++ b/sqlspec/extensions/adk/artifact/store.py @@ -9,12 +9,12 @@ """ import logging -import re from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Final, Generic, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar from sqlspec.extensions.adk._config_utils import _get_adk_artifact_store_config, _get_adk_config_from_extension from sqlspec.observability import resolve_db_system +from sqlspec.utils.identifiers import validate_identifier from sqlspec.utils.logging import get_logger, log_with_context if TYPE_CHECKING: @@ -27,34 +27,6 @@ __all__ = ("BaseAsyncADKArtifactStore", "BaseSyncADKArtifactStore") -VALID_TABLE_NAME_PATTERN: Final = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") -MAX_TABLE_NAME_LENGTH: Final = 63 - - -def _validate_table_name(table_name: str) -> None: - """Validate table name for SQL safety. - - Args: - table_name: Table name to validate. - - Raises: - ValueError: If table name is invalid. - """ - if not table_name: - msg = "Table name cannot be empty" - raise ValueError(msg) - - if len(table_name) > MAX_TABLE_NAME_LENGTH: - msg = f"Table name too long: {len(table_name)} chars (max {MAX_TABLE_NAME_LENGTH})" - raise ValueError(msg) - - if not VALID_TABLE_NAME_PATTERN.match(table_name): - msg = ( - f"Invalid table name: {table_name!r}. " - "Must start with letter/underscore and contain only alphanumeric characters and underscores" - ) - raise ValueError(msg) - class BaseAsyncADKArtifactStore(ABC, Generic[ConfigT]): """Base class for async SQLSpec-backed ADK artifact metadata stores. @@ -84,7 +56,7 @@ def __init__(self, config: ConfigT) -> None: self._config = config store_config = _get_adk_artifact_store_config(self._config) self._artifact_table: str = store_config["artifact_table"] - _validate_table_name(self._artifact_table) + validate_identifier(self._artifact_table, label="table name") def _get_adk_config(self) -> "dict[str, Any]": """Extract ADK configuration from extension_config. @@ -239,7 +211,7 @@ def __init__(self, config: ConfigT) -> None: self._config = config store_config = _get_adk_artifact_store_config(self._config) self._artifact_table: str = store_config["artifact_table"] - _validate_table_name(self._artifact_table) + validate_identifier(self._artifact_table, label="table name") def _get_adk_config(self) -> "dict[str, Any]": """Extract ADK configuration from extension_config. diff --git a/sqlspec/extensions/adk/memory/store.py b/sqlspec/extensions/adk/memory/store.py index 8301df2f7..fbe773dfc 100644 --- a/sqlspec/extensions/adk/memory/store.py +++ b/sqlspec/extensions/adk/memory/store.py @@ -7,6 +7,7 @@ from sqlspec.extensions.adk._config_utils import _ADKMemoryStoreConfig, _get_adk_memory_store_config from sqlspec.observability import resolve_db_system +from sqlspec.utils.identifiers import validate_identifier from sqlspec.utils.logging import get_logger, log_with_context if TYPE_CHECKING: @@ -19,9 +20,7 @@ __all__ = ("BaseAsyncADKMemoryStore", "BaseSyncADKMemoryStore") -VALID_TABLE_NAME_PATTERN: Final = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") COLUMN_NAME_PATTERN: Final = re.compile(r"^(\w+)") -MAX_TABLE_NAME_LENGTH: Final = 63 def _parse_owner_id_column(owner_id_column_ddl: str) -> str: @@ -44,31 +43,6 @@ def _parse_owner_id_column(owner_id_column_ddl: str) -> str: return match.group(1) -def _validate_table_name(table_name: str) -> None: - """Validate table name for SQL safety. - - Args: - table_name: Table name to validate. - - Raises: - ValueError: If table name is invalid. - """ - if not table_name: - msg = "Table name cannot be empty" - raise ValueError(msg) - - if len(table_name) > MAX_TABLE_NAME_LENGTH: - msg = f"Table name too long: {len(table_name)} chars (max {MAX_TABLE_NAME_LENGTH})" - raise ValueError(msg) - - if not VALID_TABLE_NAME_PATTERN.match(table_name): - msg = ( - f"Invalid table name: {table_name!r}. " - "Must start with letter/underscore and contain only alphanumeric characters and underscores" - ) - raise ValueError(msg) - - class BaseAsyncADKMemoryStore(ABC, Generic[ConfigT]): """Base class for async SQLSpec-backed ADK memory stores. @@ -131,7 +105,7 @@ def __init__(self, config: ConfigT) -> None: self._owner_id_column_name: str | None = ( _parse_owner_id_column(self._owner_id_column_ddl) if self._owner_id_column_ddl else None ) - _validate_table_name(self._memory_table) + validate_identifier(self._memory_table, label="table name") def _get_store_config_from_extension(self) -> "_ADKMemoryStoreConfig": """Extract ADK memory configuration from config.extension_config. @@ -359,7 +333,7 @@ def __init__(self, config: ConfigT) -> None: self._owner_id_column_name: str | None = ( _parse_owner_id_column(self._owner_id_column_ddl) if self._owner_id_column_ddl else None ) - _validate_table_name(self._memory_table) + validate_identifier(self._memory_table, label="table name") def _get_store_config_from_extension(self) -> "_ADKMemoryStoreConfig": """Extract ADK memory configuration from config.extension_config. diff --git a/sqlspec/extensions/adk/store.py b/sqlspec/extensions/adk/store.py index 04ae2d874..47fe0be64 100644 --- a/sqlspec/extensions/adk/store.py +++ b/sqlspec/extensions/adk/store.py @@ -7,6 +7,7 @@ from sqlspec.extensions.adk._config_utils import _get_adk_session_store_config from sqlspec.observability import resolve_db_system +from sqlspec.utils.identifiers import validate_identifier from sqlspec.utils.logging import get_logger, log_with_context if TYPE_CHECKING: @@ -21,9 +22,7 @@ __all__ = ("BaseAsyncADKStore", "BaseSyncADKStore") -VALID_TABLE_NAME_PATTERN: Final = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") COLUMN_NAME_PATTERN: Final = re.compile(r"^(\w+)") -MAX_TABLE_NAME_LENGTH: Final = 63 def _parse_owner_id_column(owner_id_column_ddl: str) -> str: @@ -55,37 +54,6 @@ def _parse_owner_id_column(owner_id_column_ddl: str) -> str: return match.group(1) -def _validate_table_name(table_name: str) -> None: - """Validate table name for SQL safety. - - Args: - table_name: Table name to validate. - - Raises: - ValueError: If table name is invalid. - - Notes: - - Must start with letter or underscore - - Can only contain letters, numbers, and underscores - - Maximum length is 63 characters (PostgreSQL limit) - - Prevents SQL injection in table names - """ - if not table_name: - msg = "Table name cannot be empty" - raise ValueError(msg) - - if len(table_name) > MAX_TABLE_NAME_LENGTH: - msg = f"Table name too long: {len(table_name)} chars (max {MAX_TABLE_NAME_LENGTH})" - raise ValueError(msg) - - if not VALID_TABLE_NAME_PATTERN.match(table_name): - msg = ( - f"Invalid table name: {table_name!r}. " - "Must start with letter/underscore and contain only alphanumeric characters and underscores" - ) - raise ValueError(msg) - - class BaseAsyncADKStore(ABC, Generic[ConfigT]): """Base class for async SQLSpec-backed ADK session stores. @@ -133,8 +101,8 @@ def __init__(self, config: ConfigT) -> None: self._owner_id_column_name: str | None = ( _parse_owner_id_column(self._owner_id_column_ddl) if self._owner_id_column_ddl else None ) - _validate_table_name(self._session_table) - _validate_table_name(self._events_table) + validate_identifier(self._session_table, label="table name") + validate_identifier(self._events_table, label="table name") def _get_store_config_from_extension(self) -> "dict[str, Any]": """Extract ADK store configuration from config.extension_config. @@ -383,8 +351,8 @@ def __init__(self, config: ConfigT) -> None: self._owner_id_column_name: str | None = ( _parse_owner_id_column(self._owner_id_column_ddl) if self._owner_id_column_ddl else None ) - _validate_table_name(self._session_table) - _validate_table_name(self._events_table) + validate_identifier(self._session_table, label="table name") + validate_identifier(self._events_table, label="table name") def _get_store_config_from_extension(self) -> "dict[str, Any]": """Extract ADK store configuration from config.extension_config. diff --git a/sqlspec/extensions/events/_store.py b/sqlspec/extensions/events/_store.py index f2df503fe..e53a11fd5 100644 --- a/sqlspec/extensions/events/_store.py +++ b/sqlspec/extensions/events/_store.py @@ -1,10 +1,10 @@ """Base classes for adapter-specific event queue stores.""" -import re from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast from sqlspec.exceptions import EventChannelError +from sqlspec.utils.identifiers import validate_identifier if TYPE_CHECKING: from sqlspec.config import DatabaseConfigProtocol @@ -13,25 +13,23 @@ __all__ = ("BaseEventQueueStore", "normalize_event_channel_name", "normalize_queue_table_name") -_IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") - def normalize_queue_table_name(name: str) -> str: """Validate schema-qualified identifiers and return normalized name.""" - segments = name.split(".") - for segment in segments: - if not _IDENTIFIER_PATTERN.match(segment): - msg = f"Invalid events table name: {name}" - raise EventChannelError(msg) - return name + try: + return validate_identifier(name, allow_schema_qualifier=True, error_cls=EventChannelError) + except EventChannelError as exc: + msg = f"Invalid events table name: {name}" + raise EventChannelError(msg) from exc def normalize_event_channel_name(name: str) -> str: """Validate event channel identifiers and return normalized name.""" - if not _IDENTIFIER_PATTERN.match(name): + try: + return validate_identifier(name, error_cls=EventChannelError) + except EventChannelError as exc: msg = f"Invalid events channel name: {name}" - raise EventChannelError(msg) - return name + raise EventChannelError(msg) from exc class BaseEventQueueStore(ABC, Generic[ConfigT]): diff --git a/sqlspec/extensions/litestar/channels.py b/sqlspec/extensions/litestar/channels.py index 4c24c8923..af5476d78 100644 --- a/sqlspec/extensions/litestar/channels.py +++ b/sqlspec/extensions/litestar/channels.py @@ -3,11 +3,11 @@ import asyncio import base64 import hashlib -import re from typing import TYPE_CHECKING, Any from litestar.channels.backends.base import ChannelsBackend +from sqlspec.utils.identifiers import validate_identifier from sqlspec.utils.logging import get_logger if TYPE_CHECKING: @@ -17,8 +17,6 @@ logger = get_logger("sqlspec.extensions.litestar.channels") -_IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") - class SQLSpecChannelsBackend(ChannelsBackend): """A Litestar Channels backend implemented on top of SQLSpec's EventChannel. @@ -36,9 +34,11 @@ class SQLSpecChannelsBackend(ChannelsBackend): def __init__( self, event_channel: "AsyncEventChannel", *, channel_prefix: str = "litestar", poll_interval: float = 0.2 ) -> None: - if not _IDENTIFIER_PATTERN.match(channel_prefix): + try: + validate_identifier(channel_prefix, label="channel_prefix") + except ValueError as exc: msg = f"channel_prefix must be a valid identifier, got: {channel_prefix!r}" - raise ValueError(msg) + raise ValueError(msg) from exc if poll_interval <= 0: msg = "poll_interval must be greater than zero" raise ValueError(msg) diff --git a/sqlspec/extensions/litestar/store.py b/sqlspec/extensions/litestar/store.py index a57ec6477..220d1fea6 100644 --- a/sqlspec/extensions/litestar/store.py +++ b/sqlspec/extensions/litestar/store.py @@ -1,11 +1,11 @@ """Base session store classes for Litestar integration.""" -import re from abc import ABC, abstractmethod from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Any, Final, Generic, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast from sqlspec.observability import resolve_db_system +from sqlspec.utils.identifiers import validate_identifier from sqlspec.utils.logging import get_logger from sqlspec.utils.type_guards import has_extension_config @@ -22,9 +22,6 @@ __all__ = ("BaseSQLSpecStore",) -VALID_TABLE_NAME_PATTERN: Final = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") -MAX_TABLE_NAME_LENGTH: Final = 63 - class BaseSQLSpecStore(ABC, Generic[ConfigT]): """Base class for SQLSpec-backed Litestar session stores. @@ -73,7 +70,7 @@ def __init__(self, config: ConfigT) -> None: """ self._config = config self._table_name = self._get_table_name_from_config() - self._validate_table_name(self._table_name) + validate_identifier(self._table_name, label="table name") def _get_table_name_from_config(self) -> str: """Extract table name from config.extension_config. @@ -268,31 +265,3 @@ def _value_to_bytes(self, value: "str | bytes") -> bytes: if isinstance(value, str): return value.encode("utf-8") return value - - @staticmethod - def _validate_table_name(table_name: str) -> None: - """Validate table name for SQL safety. - - Args: - table_name: Table name to validate. - - Raises: - ValueError: If table name is invalid. - - Notes: - - Must start with letter or underscore - - Can only contain letters, numbers, and underscores - - Maximum length is 63 characters (PostgreSQL limit) - - Prevents SQL injection in table names - """ - if not table_name: - msg = "Table name cannot be empty" - raise ValueError(msg) - - if len(table_name) > MAX_TABLE_NAME_LENGTH: - msg = f"Table name too long: {len(table_name)} chars (max {MAX_TABLE_NAME_LENGTH})" - raise ValueError(msg) - - if not VALID_TABLE_NAME_PATTERN.match(table_name): - msg = f"Invalid table name: {table_name!r}. Must start with letter/underscore and contain only alphanumeric characters and underscores" - raise ValueError(msg) diff --git a/sqlspec/utils/identifiers.py b/sqlspec/utils/identifiers.py new file mode 100644 index 000000000..857ae9b73 --- /dev/null +++ b/sqlspec/utils/identifiers.py @@ -0,0 +1,56 @@ +"""SQL identifier validation helpers.""" + +import re +from typing import Final + +__all__ = ("DEFAULT_MAX_IDENTIFIER_LENGTH", "validate_identifier") + +_IDENTIFIER_PATTERN: Final = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") +DEFAULT_MAX_IDENTIFIER_LENGTH: Final = 63 + + +def validate_identifier( + name: str, + *, + max_length: int = DEFAULT_MAX_IDENTIFIER_LENGTH, + allow_schema_qualifier: bool = False, + error_cls: type[Exception] = ValueError, + label: str = "identifier", +) -> str: + """Validate a SQL identifier and return it unchanged. + + Args: + name: Identifier to validate. + max_length: Maximum length per identifier segment. + allow_schema_qualifier: Whether dotted schema-qualified identifiers are allowed. + error_cls: Exception class raised when validation fails. + label: Domain-specific label to use in error messages. + + Returns: + The validated identifier. + + Raises: + error_cls: If the identifier is empty, too long, schema-qualified when not allowed, or malformed. + """ + label_lower = label.lower() + label_title = label_lower.capitalize() + if not name: + msg = f"{label_title} cannot be empty" + raise error_cls(msg) + + if not allow_schema_qualifier and "." in name: + msg = f"Schema qualifier not allowed for {label_lower}: {name!r}" + raise error_cls(msg) + + segments = name.split(".") if allow_schema_qualifier else [name] + for segment in segments: + if len(segment) > max_length: + msg = f"{label_title} too long: {len(segment)} chars (max {max_length}) in {name!r}" + raise error_cls(msg) + if not _IDENTIFIER_PATTERN.match(segment): + msg = ( + f"Invalid {label_lower}: {name!r}. " + "Must start with letter/underscore and contain only alphanumeric characters and underscores" + ) + raise error_cls(msg) + return name diff --git a/tests/unit/utils/test_identifiers.py b/tests/unit/utils/test_identifiers.py new file mode 100644 index 000000000..30f962ae8 --- /dev/null +++ b/tests/unit/utils/test_identifiers.py @@ -0,0 +1,66 @@ +"""Tests for SQL identifier validation helpers.""" + +import pytest + +from sqlspec.utils.identifiers import DEFAULT_MAX_IDENTIFIER_LENGTH, validate_identifier + + +def test_validate_identifier_returns_valid_name_unchanged() -> None: + """Valid identifiers are returned unchanged.""" + assert validate_identifier("adk_sessions") == "adk_sessions" + + +@pytest.mark.parametrize("name", ["", "1_table", "table-name", "table name", "foo; DROP TABLE x"]) +def test_validate_identifier_rejects_invalid_names(name: str) -> None: + """Invalid identifiers are rejected.""" + with pytest.raises(ValueError): + validate_identifier(name) + + +def test_validate_identifier_rejects_names_longer_than_default_limit() -> None: + """Identifiers longer than the default max length are rejected.""" + name = "a" * (DEFAULT_MAX_IDENTIFIER_LENGTH + 1) + + with pytest.raises(ValueError, match="Identifier too long"): + validate_identifier(name) + + +def test_validate_identifier_rejects_schema_qualifier_by_default() -> None: + """Schema-qualified names are rejected unless explicitly enabled.""" + with pytest.raises(ValueError, match="Schema qualifier not allowed"): + validate_identifier("public.adk_sessions") + + +def test_validate_identifier_accepts_schema_qualified_name_when_enabled() -> None: + """Schema-qualified names are validated segment by segment when enabled.""" + assert validate_identifier("public.adk_sessions", allow_schema_qualifier=True) == "public.adk_sessions" + + +def test_validate_identifier_accepts_multi_segment_qualified_name_when_enabled() -> None: + """Existing event queue behavior accepts multi-segment qualified names.""" + name = "catalog.public.adk_events" + + assert validate_identifier(name, allow_schema_qualifier=True) == name + + +@pytest.mark.parametrize("name", [".adk_sessions", "public.", "public..adk_sessions", "public.1_sessions"]) +def test_validate_identifier_rejects_invalid_schema_qualified_segments(name: str) -> None: + """Every schema-qualified segment must be a valid identifier.""" + with pytest.raises(ValueError, match="Invalid identifier"): + validate_identifier(name, allow_schema_qualifier=True) + + +def test_validate_identifier_uses_custom_error_class() -> None: + """Callers can preserve their existing exception type.""" + + class IdentifierError(Exception): + pass + + with pytest.raises(IdentifierError): + validate_identifier("invalid-name", error_cls=IdentifierError) + + +def test_validate_identifier_uses_custom_label_in_error_messages() -> None: + """Callers can preserve existing domain-specific error messages.""" + with pytest.raises(ValueError, match="Invalid table name"): + validate_identifier("invalid-name", label="table name") From 81557728529200f7ea2b31b036815239572affd5 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 23 May 2026 18:34:35 +0000 Subject: [PATCH 03/29] chore: ignore PLW0717 ruff rule --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index fe82b3b83..bd669fab7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -587,6 +587,7 @@ ignore = [ "B903", # class could be a dataclass or named tuple "PLW0603", # Using the global statement to update is discouraged "PLW0108", # Replace lambda expression with a def + "PLW0717", # Too many statements in try clause; temporarily disabled while rule is new "RUF067", # ruff - init should only have import statements ] select = ["ALL"] From 54dc2088deee1cde23cc7d679cc792925b4e7cae Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 23 May 2026 18:34:54 +0000 Subject: [PATCH 04/29] feat(adk): align event payload column name --- sqlspec/adapters/adbc/adk/store.py | 46 +++++----- sqlspec/adapters/aiomysql/adk/store.py | 28 +++--- sqlspec/adapters/aiosqlite/adk/store.py | 12 +-- sqlspec/adapters/asyncmy/adk/store.py | 28 +++--- sqlspec/adapters/asyncpg/adk/store.py | 18 ++-- .../adapters/cockroach_asyncpg/adk/store.py | 20 ++-- .../adapters/cockroach_psycopg/adk/store.py | 48 +++++----- sqlspec/adapters/duckdb/adk/store.py | 32 +++---- sqlspec/adapters/mysqlconnector/adk/store.py | 48 +++++----- sqlspec/adapters/oracledb/adk/store.py | 92 +++++++++---------- sqlspec/adapters/psqlpy/adk/store.py | 18 ++-- sqlspec/adapters/psycopg/adk/store.py | 44 ++++----- sqlspec/adapters/pymysql/adk/store.py | 24 ++--- sqlspec/adapters/spanner/adk/store.py | 20 ++-- sqlspec/adapters/sqlite/adk/store.py | 12 +-- sqlspec/extensions/adk/_types.py | 6 +- sqlspec/extensions/adk/converters.py | 10 +- .../adk/test_dialect_integration.py | 4 +- .../extensions/adk/test_dialect_support.py | 18 ++-- .../adbc/extensions/adk/test_edge_cases.py | 8 +- .../extensions/adk/test_event_operations.py | 62 ++++++------- .../aiomysql/extensions/adk/test_store.py | 14 +-- .../aiosqlite/extensions/adk/test_store.py | 8 +- .../asyncmy/extensions/adk/test_store.py | 18 ++-- .../duckdb/extensions/adk/test_store.py | 36 ++++---- .../extensions/adk/test_store.py | 18 ++-- .../extensions/adk/test_oracle_specific.py | 26 +++--- .../spanner/extensions/adk/test_adk_store.py | 10 +- .../sqlite/extensions/adk/test_store.py | 8 +- .../test_oracledb/test_oracle_adk_store.py | 18 ++-- .../adapters/test_psycopg/test_adk_store.py | 6 +- .../adapters/test_spanner/test_adk_store.py | 4 +- .../extensions/test_adk/test_converters.py | 36 ++++---- .../unit/extensions/test_adk/test_service.py | 2 +- .../extensions/test_adk/test_store_config.py | 2 +- 35 files changed, 402 insertions(+), 402 deletions(-) diff --git a/sqlspec/adapters/adbc/adk/store.py b/sqlspec/adapters/adbc/adk/store.py index 1b1166f29..4b7b76e81 100644 --- a/sqlspec/adapters/adbc/adk/store.py +++ b/sqlspec/adapters/adbc/adk/store.py @@ -35,14 +35,14 @@ class AdbcADKStore(BaseAsyncADKStore["AdbcConfig"]): transfer across multiple databases (PostgreSQL, SQLite, DuckDB, etc.). Events use the new 5-column contract: session_id, invocation_id, author, - timestamp, and event_json. The full ADK Event payload is stored as a - single JSON blob in event_json using a dialect-appropriate column type + timestamp, and event_data. The full ADK Event payload is stored as a + single JSON blob in event_data using a dialect-appropriate column type (JSONB for PostgreSQL, JSON for DuckDB, VARIANT for Snowflake, TEXT for SQLite and generic fallback). Provides: - Session state management with JSON serialization - - Event history tracking via single event_json blob + - Event history tracking via single event_data blob - Atomic event insert + session state update - Timezone-aware timestamps - Foreign key constraints with cascade delete @@ -69,7 +69,7 @@ class AdbcADKStore(BaseAsyncADKStore["AdbcConfig"]): store.ensure_tables() Notes: - - Dialect-appropriate JSON type for event_json storage + - Dialect-appropriate JSON type for event_data storage - TIMESTAMP for timezone-aware timestamps (driver-dependent precision) - Parameter style: ``?`` universally across ADBC backends - State and JSON fields use to_json/from_json for serialization @@ -306,7 +306,7 @@ def _get_events_ddl_postgresql(self) -> str: SQL to create events table optimized for PostgreSQL. Notes: - Uses JSONB for event_json to enable indexing and query support. + Uses JSONB for event_data to enable indexing and query support. """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( @@ -314,7 +314,7 @@ def _get_events_ddl_postgresql(self) -> str: invocation_id VARCHAR(256) NOT NULL, author VARCHAR(256) NOT NULL, timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - event_json JSONB NOT NULL, + event_data JSONB NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ) """ @@ -326,7 +326,7 @@ def _get_events_ddl_sqlite(self) -> str: SQL to create events table optimized for SQLite. Notes: - Uses TEXT for event_json (SQLite has no native JSON column type). + Uses TEXT for event_data (SQLite has no native JSON column type). """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( @@ -334,7 +334,7 @@ def _get_events_ddl_sqlite(self) -> str: invocation_id TEXT NOT NULL, author TEXT NOT NULL, timestamp REAL NOT NULL, - event_json TEXT NOT NULL, + event_data TEXT NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ) """ @@ -346,7 +346,7 @@ def _get_events_ddl_duckdb(self) -> str: SQL to create events table optimized for DuckDB. Notes: - Uses JSON for event_json (DuckDB native JSON type). + Uses JSON for event_data (DuckDB native JSON type). """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( @@ -354,7 +354,7 @@ def _get_events_ddl_duckdb(self) -> str: invocation_id VARCHAR(256) NOT NULL, author VARCHAR(256) NOT NULL, timestamp TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - event_json JSON NOT NULL, + event_data JSON NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ) """ @@ -366,7 +366,7 @@ def _get_events_ddl_snowflake(self) -> str: SQL to create events table optimized for Snowflake. Notes: - Uses VARIANT for event_json (Snowflake semi-structured type). + Uses VARIANT for event_data (Snowflake semi-structured type). """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( @@ -374,7 +374,7 @@ def _get_events_ddl_snowflake(self) -> str: invocation_id VARCHAR NOT NULL, author VARCHAR NOT NULL, timestamp TIMESTAMP_TZ NOT NULL DEFAULT CURRENT_TIMESTAMP(), - event_json VARIANT NOT NULL, + event_data VARIANT NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ) """ @@ -386,7 +386,7 @@ def _get_events_ddl_generic(self) -> str: SQL to create events table using generic types. Notes: - Uses TEXT for event_json (maximum portability). + Uses TEXT for event_data (maximum portability). """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( @@ -394,7 +394,7 @@ def _get_events_ddl_generic(self) -> str: invocation_id VARCHAR(256) NOT NULL, author VARCHAR(256) NOT NULL, timestamp TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - event_json TEXT NOT NULL, + event_data TEXT NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ) """ @@ -688,10 +688,10 @@ def _insert_event(self, event_record: "EventRecord") -> None: Args: event_record: Event record to store. """ - event_json = self._serialize_json_field(event_record["event_json"]) + event_data = self._serialize_json_field(event_record["event_data"]) sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + session_id, invocation_id, author, timestamp, event_data ) VALUES (?, ?, ?, ?, ?) """ @@ -705,7 +705,7 @@ def _insert_event(self, event_record: "EventRecord") -> None: event_record["invocation_id"], event_record["author"], event_record["timestamp"], - event_json, + event_data, ), ) conn.commit() @@ -731,7 +731,7 @@ def _append_event_and_update_state( """ insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + session_id, invocation_id, author, timestamp, event_data ) VALUES (?, ?, ?, ?, ?) """ update_sql = f""" @@ -745,7 +745,7 @@ def _append_event_and_update_state( WHERE id = ? """ state_json = self._serialize_state(state) - event_json = self._serialize_json_field(event_record["event_json"]) + event_data = self._serialize_json_field(event_record["event_data"]) with self._config.provide_connection() as conn: cursor = conn.cursor() @@ -757,7 +757,7 @@ def _append_event_and_update_state( event_record["invocation_id"], event_record["author"], event_record["timestamp"], - event_json, + event_data, ), ) cursor.execute(update_sql, (state_json, session_id)) @@ -806,7 +806,7 @@ def _get_events( Notes: Uses index on (session_id, timestamp ASC). Returns the 5-column EventRecord (session_id, invocation_id, - author, timestamp, event_json). + author, timestamp, event_data). """ where_clauses = ["session_id = ?"] params: list[Any] = [session_id] @@ -818,7 +818,7 @@ def _get_events( where_clause = " AND ".join(where_clauses) limit_clause = f" LIMIT {limit}" if limit else "" sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_json + SELECT session_id, invocation_id, author, timestamp, event_data FROM {self._events_table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} @@ -837,7 +837,7 @@ def _get_events( invocation_id=row[1], author=row[2], timestamp=row[3], - event_json=self._deserialize_json_field(row[4]) or {}, + event_data=self._deserialize_json_field(row[4]) or {}, ) for row in rows ] diff --git a/sqlspec/adapters/aiomysql/adk/store.py b/sqlspec/adapters/aiomysql/adk/store.py index e800ba286..5b0c67b53 100644 --- a/sqlspec/adapters/aiomysql/adk/store.py +++ b/sqlspec/adapters/aiomysql/adk/store.py @@ -28,7 +28,7 @@ class AiomysqlADKStore(BaseAsyncADKStore["AiomysqlConfig"]): Implements session and event storage for Google Agent Development Kit using MySQL/MariaDB via the aiomysql driver. Provides: - Session state management with JSON storage - - Full-event JSON storage (single ``event_json`` column) + - Full-event JSON storage (single ``event_data`` column) - Atomic event-append + state-update in one transaction - Microsecond-precision timestamps - Foreign key constraints with cascade delete @@ -115,7 +115,7 @@ async def _get_create_events_table_sql(self) -> str: Post clean-break schema: 5 columns only. - session_id, invocation_id, author: indexed scalars - timestamp: microsecond-precision TIMESTAMP - - event_json: full Event as native JSON + - event_data: full Event as native JSON """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( @@ -123,7 +123,7 @@ async def _get_create_events_table_sql(self) -> str: invocation_id VARCHAR(256) NOT NULL, author VARCHAR(128) NOT NULL, timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - event_json JSON NOT NULL, + event_data JSON NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE, INDEX idx_{self._events_table}_session (session_id, timestamp ASC) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci @@ -317,14 +317,14 @@ async def append_event(self, event_record: EventRecord) -> None: Args: event_record: Event record with 5 keys (session_id, invocation_id, - author, timestamp, event_json). + author, timestamp, event_data). """ - event_json = event_record["event_json"] - event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json + event_data = event_record["event_data"] + event_data_str = to_json(event_data) if not isinstance(event_data, str) else event_data sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + session_id, invocation_id, author, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """ @@ -339,7 +339,7 @@ async def append_event(self, event_record: EventRecord) -> None: event_record["invocation_id"], event_record["author"], event_record["timestamp"], - event_json_str, + event_data_str, ), ) await conn.commit() @@ -358,13 +358,13 @@ async def append_event_and_update_state( session_id: Session identifier whose state should be updated. state: Post-append durable state snapshot. """ - event_json = event_record["event_json"] - event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json + event_data = event_record["event_data"] + event_data_str = to_json(event_data) if not isinstance(event_data, str) else event_data state_json = to_json(state) insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + session_id, invocation_id, author, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """ @@ -391,7 +391,7 @@ async def append_event_and_update_state( event_record["invocation_id"], event_record["author"], event_record["timestamp"], - event_json_str, + event_data_str, ), ) await cursor.execute(update_sql, (state_json, session_id)) @@ -437,7 +437,7 @@ async def get_events( limit_clause = f" LIMIT {limit}" if limit else "" sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_json + SELECT session_id, invocation_id, author, timestamp, event_data FROM {self._events_table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} @@ -457,7 +457,7 @@ async def get_events( invocation_id=row[1], author=row[2], timestamp=row[3], - event_json=from_json(row[4]) if isinstance(row[4], str) else row[4], + event_data=from_json(row[4]) if isinstance(row[4], str) else row[4], ) for row in rows ] diff --git a/sqlspec/adapters/aiosqlite/adk/store.py b/sqlspec/adapters/aiosqlite/adk/store.py index d8012f258..dfc12a9af 100644 --- a/sqlspec/adapters/aiosqlite/adk/store.py +++ b/sqlspec/adapters/aiosqlite/adk/store.py @@ -391,16 +391,16 @@ async def append_event(self, event_record: EventRecord) -> None: Args: event_record: Event record with 5 keys: session_id, invocation_id, - author, timestamp, event_json. + author, timestamp, event_data. Notes: Uses Julian Day for timestamp. - event_json dict is serialized to TEXT as event_data column. + event_data dict is serialized to TEXT as event_data column. """ import uuid timestamp_julian = _datetime_to_julian(event_record["timestamp"]) - event_data_json = to_json(event_record["event_json"]) + event_data_json = to_json(event_record["event_data"]) event_id = str(uuid.uuid4()) sql = f""" @@ -442,7 +442,7 @@ async def append_event_and_update_state( import uuid timestamp_julian = _datetime_to_julian(event_record["timestamp"]) - event_data_json = to_json(event_record["event_json"]) + event_data_json = to_json(event_record["event_data"]) now_julian = _datetime_to_julian(datetime.now(timezone.utc)) state_json = to_json(state) event_id = str(uuid.uuid4()) @@ -505,7 +505,7 @@ async def get_events( Notes: Uses index on (session_id, timestamp ASC). - Parses event_data TEXT back to dict for event_json field. + Parses event_data TEXT back to dict for event_data field. """ where_clauses = ["session_id = ?"] params: list[Any] = [session_id] @@ -536,7 +536,7 @@ async def get_events( invocation_id=row[2], author=row[3], timestamp=_julian_to_datetime(row[4]), - event_json=from_json(row[5]) if row[5] else {}, + event_data=from_json(row[5]) if row[5] else {}, ) for row in rows ] diff --git a/sqlspec/adapters/asyncmy/adk/store.py b/sqlspec/adapters/asyncmy/adk/store.py index 4f27af634..43364de36 100644 --- a/sqlspec/adapters/asyncmy/adk/store.py +++ b/sqlspec/adapters/asyncmy/adk/store.py @@ -27,7 +27,7 @@ class AsyncmyADKStore(BaseAsyncADKStore["AsyncmyConfig"]): Implements session and event storage for Google Agent Development Kit using MySQL/MariaDB via the AsyncMy driver. Provides: - Session state management with JSON storage - - Full-event JSON storage (single ``event_json`` column) + - Full-event JSON storage (single ``event_data`` column) - Atomic event-append + state-update in one transaction - Microsecond-precision timestamps - Foreign key constraints with cascade delete @@ -114,7 +114,7 @@ async def _get_create_events_table_sql(self) -> str: Post clean-break schema: 5 columns only. - session_id, invocation_id, author: indexed scalars - timestamp: microsecond-precision TIMESTAMP - - event_json: full Event as native JSON + - event_data: full Event as native JSON """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( @@ -122,7 +122,7 @@ async def _get_create_events_table_sql(self) -> str: invocation_id VARCHAR(256) NOT NULL, author VARCHAR(128) NOT NULL, timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - event_json JSON NOT NULL, + event_data JSON NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE, INDEX idx_{self._events_table}_session (session_id, timestamp ASC) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci @@ -301,14 +301,14 @@ async def append_event(self, event_record: EventRecord) -> None: Args: event_record: Event record with 5 keys (session_id, invocation_id, - author, timestamp, event_json). + author, timestamp, event_data). """ - event_json = event_record["event_json"] - event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json + event_data = event_record["event_data"] + event_data_str = to_json(event_data) if not isinstance(event_data, str) else event_data sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + session_id, invocation_id, author, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """ @@ -320,7 +320,7 @@ async def append_event(self, event_record: EventRecord) -> None: event_record["invocation_id"], event_record["author"], event_record["timestamp"], - event_json_str, + event_data_str, ), ) await conn.commit() @@ -339,13 +339,13 @@ async def append_event_and_update_state( session_id: Session identifier whose state should be updated. state: Post-append durable state snapshot. """ - event_json = event_record["event_json"] - event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json + event_data = event_record["event_data"] + event_data_str = to_json(event_data) if not isinstance(event_data, str) else event_data state_json = to_json(state) insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + session_id, invocation_id, author, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """ @@ -369,7 +369,7 @@ async def append_event_and_update_state( event_record["invocation_id"], event_record["author"], event_record["timestamp"], - event_json_str, + event_data_str, ), ) await cursor.execute(update_sql, (state_json, session_id)) @@ -415,7 +415,7 @@ async def get_events( limit_clause = f" LIMIT {limit}" if limit else "" sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_json + SELECT session_id, invocation_id, author, timestamp, event_data FROM {self._events_table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} @@ -432,7 +432,7 @@ async def get_events( invocation_id=row[1], author=row[2], timestamp=row[3], - event_json=from_json(row[4]) if isinstance(row[4], str) else row[4], + event_data=from_json(row[4]) if isinstance(row[4], str) else row[4], ) for row in rows ] diff --git a/sqlspec/adapters/asyncpg/adk/store.py b/sqlspec/adapters/asyncpg/adk/store.py index c4a7a110c..21bf5e9ae 100644 --- a/sqlspec/adapters/asyncpg/adk/store.py +++ b/sqlspec/adapters/asyncpg/adk/store.py @@ -25,11 +25,11 @@ class AsyncpgADKStore(BaseAsyncADKStore[AsyncConfigT]): Implements session and event storage for Google Agent Development Kit using PostgreSQL via asyncpg. Events are stored as a single JSONB blob - (``event_json``) alongside indexed scalar columns for efficient querying. + (``event_data``) alongside indexed scalar columns for efficient querying. Provides: - Session state management with JSONB storage - - Full-fidelity event storage via ``event_json`` JSONB column + - Full-fidelity event storage via ``event_data`` JSONB column - Atomic ``append_event_and_update_state`` for durable session mutations - Microsecond-precision timestamps with TIMESTAMPTZ - Foreign key constraints with cascade delete @@ -79,7 +79,7 @@ async def _get_create_events_table_sql(self) -> str: invocation_id VARCHAR(256) NOT NULL, author VARCHAR(256) NOT NULL, timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - event_json JSONB NOT NULL, + event_data JSONB NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ) WITH (fillfactor = 80); @@ -195,7 +195,7 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis async def append_event(self, event_record: EventRecord) -> None: sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + session_id, invocation_id, author, timestamp, event_data ) VALUES ($1, $2, $3, $4, $5) """ @@ -206,7 +206,7 @@ async def append_event(self, event_record: EventRecord) -> None: event_record["invocation_id"], event_record["author"], event_record["timestamp"], - event_record["event_json"], + event_record["event_data"], ) async def append_event_and_update_state( @@ -214,7 +214,7 @@ async def append_event_and_update_state( ) -> SessionRecord: insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + session_id, invocation_id, author, timestamp, event_data ) VALUES ($1, $2, $3, $4, $5) """ update_sql = f""" @@ -231,7 +231,7 @@ async def append_event_and_update_state( event_record["invocation_id"], event_record["author"], event_record["timestamp"], - event_record["event_json"], + event_record["event_data"], ) row = await conn.fetchrow(update_sql, state, session_id) @@ -264,7 +264,7 @@ async def get_events( params.append(limit) sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_json + SELECT session_id, invocation_id, author, timestamp, event_data FROM {self._events_table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} @@ -280,7 +280,7 @@ async def get_events( invocation_id=row["invocation_id"], author=row["author"], timestamp=row["timestamp"], - event_json=row["event_json"], + event_data=row["event_data"], ) for row in rows ] diff --git a/sqlspec/adapters/cockroach_asyncpg/adk/store.py b/sqlspec/adapters/cockroach_asyncpg/adk/store.py index f11b64c61..31182fc0f 100644 --- a/sqlspec/adapters/cockroach_asyncpg/adk/store.py +++ b/sqlspec/adapters/cockroach_asyncpg/adk/store.py @@ -25,7 +25,7 @@ class CockroachAsyncpgADKStore(BaseAsyncADKStore["CockroachAsyncpgConfig"]): Implements session and event storage for Google Agent Development Kit using CockroachDB via asyncpg in PostgreSQL compatibility mode. - Events are stored as a single JSONB blob (``event_json``) alongside + Events are stored as a single JSONB blob (``event_data``) alongside indexed scalar columns for efficient querying. CockroachDB-specific differences from native PostgreSQL: @@ -73,15 +73,15 @@ async def _get_create_events_table_sql(self) -> str: invocation_id VARCHAR(256) NOT NULL, author VARCHAR(256) NOT NULL, timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - event_json JSONB NOT NULL, + event_data JSONB NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ); CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session ON {self._events_table}(session_id, timestamp ASC); - CREATE INDEX IF NOT EXISTS idx_{self._events_table}_event_json - ON {self._events_table} USING GIN (event_json); + CREATE INDEX IF NOT EXISTS idx_{self._events_table}_event_data + ON {self._events_table} USING GIN (event_data); """ def _get_drop_tables_sql(self) -> "list[str]": @@ -197,7 +197,7 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis async def append_event(self, event_record: EventRecord) -> None: sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + session_id, invocation_id, author, timestamp, event_data ) VALUES ($1, $2, $3, $4, $5) """ @@ -208,7 +208,7 @@ async def append_event(self, event_record: EventRecord) -> None: event_record["invocation_id"], event_record["author"], event_record["timestamp"], - event_record["event_json"], + event_record["event_data"], ) async def append_event_and_update_state( @@ -216,7 +216,7 @@ async def append_event_and_update_state( ) -> SessionRecord: insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + session_id, invocation_id, author, timestamp, event_data ) VALUES ($1, $2, $3, $4, $5) """ update_sql = f""" @@ -233,7 +233,7 @@ async def append_event_and_update_state( event_record["invocation_id"], event_record["author"], event_record["timestamp"], - event_record["event_json"], + event_record["event_data"], ) row = await conn.fetchrow(update_sql, state, session_id) @@ -266,7 +266,7 @@ async def get_events( params.append(limit) sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_json + SELECT session_id, invocation_id, author, timestamp, event_data FROM {self._events_table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} @@ -284,7 +284,7 @@ async def get_events( invocation_id=row["invocation_id"], author=row["author"], timestamp=row["timestamp"], - event_json=row["event_json"], + event_data=row["event_data"], ) for row in rows ] diff --git a/sqlspec/adapters/cockroach_psycopg/adk/store.py b/sqlspec/adapters/cockroach_psycopg/adk/store.py index 03d77f670..4c45c1b7a 100644 --- a/sqlspec/adapters/cockroach_psycopg/adk/store.py +++ b/sqlspec/adapters/cockroach_psycopg/adk/store.py @@ -66,7 +66,7 @@ class CockroachPsycopgAsyncADKStore(BaseAsyncADKStore["CockroachPsycopgAsyncConf Implements session and event storage for Google Agent Development Kit using CockroachDB via psycopg in PostgreSQL compatibility mode. - Events are stored as a single JSONB blob (``event_json``) alongside + Events are stored as a single JSONB blob (``event_data``) alongside indexed scalar columns for efficient querying. CockroachDB-specific differences from native PostgreSQL: @@ -114,15 +114,15 @@ async def _get_create_events_table_sql(self) -> str: invocation_id VARCHAR(256) NOT NULL, author VARCHAR(256) NOT NULL, timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - event_json JSONB NOT NULL, + event_data JSONB NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ); CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session ON {self._events_table}(session_id, timestamp ASC); - CREATE INDEX IF NOT EXISTS idx_{self._events_table}_event_json - ON {self._events_table} USING GIN (event_json); + CREATE INDEX IF NOT EXISTS idx_{self._events_table}_event_data + ON {self._events_table} USING GIN (event_data); """ def _get_drop_tables_sql(self) -> "list[str]": @@ -246,12 +246,12 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis async def append_event(self, event_record: EventRecord) -> None: sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + session_id, invocation_id, author, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """ - event_json_value = event_record["event_json"] - jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value + event_data_value = event_record["event_data"] + jsonb_value = Jsonb(event_data_value) if isinstance(event_data_value, dict) else event_data_value async with self._config.provide_connection() as conn, conn.cursor() as cur: await cur.execute( @@ -271,7 +271,7 @@ async def append_event_and_update_state( ) -> SessionRecord: insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + session_id, invocation_id, author, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """ update_sql = f""" @@ -281,8 +281,8 @@ async def append_event_and_update_state( RETURNING id, app_name, user_id, state, create_time, update_time """ - event_json_value = event_record["event_json"] - jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value + event_data_value = event_record["event_data"] + jsonb_value = Jsonb(event_data_value) if isinstance(event_data_value, dict) else event_data_value async with self._config.provide_connection() as conn, conn.cursor() as cur: await cur.execute( @@ -328,7 +328,7 @@ async def get_events( params.append(limit) sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_json + SELECT session_id, invocation_id, author, timestamp, event_data FROM {self._events_table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} @@ -345,7 +345,7 @@ async def get_events( invocation_id=row["invocation_id"], author=row["author"], timestamp=row["timestamp"], - event_json=row["event_json"], + event_data=row["event_data"], ) for row in rows ] @@ -358,7 +358,7 @@ class CockroachPsycopgSyncADKStore(BaseAsyncADKStore["CockroachPsycopgSyncConfig Implements session and event storage for Google Agent Development Kit using CockroachDB via psycopg in PostgreSQL compatibility mode (sync). - Events are stored as a single JSONB blob (``event_json``) alongside + Events are stored as a single JSONB blob (``event_data``) alongside indexed scalar columns for efficient querying. CockroachDB-specific differences from native PostgreSQL: @@ -405,15 +405,15 @@ async def _get_create_events_table_sql(self) -> str: invocation_id VARCHAR(256) NOT NULL, author VARCHAR(256) NOT NULL, timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - event_json JSONB NOT NULL, + event_data JSONB NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ); CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session ON {self._events_table}(session_id, timestamp ASC); - CREATE INDEX IF NOT EXISTS idx_{self._events_table}_event_json - ON {self._events_table} USING GIN (event_json); + CREATE INDEX IF NOT EXISTS idx_{self._events_table}_event_data + ON {self._events_table} USING GIN (event_data); """ def _get_drop_tables_sql(self) -> "list[str]": @@ -565,7 +565,7 @@ def _append_event_and_update_state( ) -> SessionRecord: insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + session_id, invocation_id, author, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """ update_sql = f""" @@ -575,8 +575,8 @@ def _append_event_and_update_state( RETURNING id, app_name, user_id, state, create_time, update_time """ - event_json_value = event_record["event_json"] - jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value + event_data_value = event_record["event_data"] + jsonb_value = Jsonb(event_data_value) if isinstance(event_data_value, dict) else event_data_value with self._config.provide_connection() as conn, conn.cursor() as cur: cur.execute( @@ -615,12 +615,12 @@ async def append_event_and_update_state( def _insert_event(self, event_record: EventRecord) -> None: sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + session_id, invocation_id, author, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """ - event_json_value = event_record["event_json"] - jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value + event_data_value = event_record["event_data"] + jsonb_value = Jsonb(event_data_value) if isinstance(event_data_value, dict) else event_data_value with self._config.provide_connection() as conn, conn.cursor() as cur: cur.execute( @@ -648,7 +648,7 @@ def _get_events( where_clause = " AND ".join(where_clauses) limit_clause = " LIMIT %s" if limit else "" sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_json + SELECT session_id, invocation_id, author, timestamp, event_data FROM {self._events_table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} @@ -667,7 +667,7 @@ def _get_events( invocation_id=row["invocation_id"], author=row["author"], timestamp=row["timestamp"], - event_json=row["event_json"], + event_data=row["event_data"], ) for row in rows ] diff --git a/sqlspec/adapters/duckdb/adk/store.py b/sqlspec/adapters/duckdb/adk/store.py index 3d34194b0..fb689758e 100644 --- a/sqlspec/adapters/duckdb/adk/store.py +++ b/sqlspec/adapters/duckdb/adk/store.py @@ -41,7 +41,7 @@ class DuckdbADKStore(BaseAsyncADKStore["DuckDBConfig"]): using DuckDB's synchronous driver with async wrappers via ``async_()``. Provides: - Session state management with native JSON type - - Event history with single JSON blob (event_json) plus indexed scalars + - Event history with single JSON blob (event_data) plus indexed scalars - Native TIMESTAMPTZ type support - Manual cascade delete (DuckDB has no FK CASCADE) - Columnar storage for analytical queries @@ -67,9 +67,9 @@ class DuckdbADKStore(BaseAsyncADKStore["DuckDBConfig"]): await store.ensure_tables() Notes: - - Uses DuckDB native JSON type for event_json and state + - Uses DuckDB native JSON type for event_data and state - TIMESTAMPTZ for date/time storage with microsecond precision - - event_json stores the full ADK Event as a single JSON blob + - event_data stores the full ADK Event as a single JSON blob - Columnar storage provides excellent analytical query performance - DuckDB doesn't support CASCADE in foreign keys (manual cascade required) - Optimized for OLAP workloads; for high-concurrency writes use PostgreSQL @@ -131,8 +131,8 @@ async def _get_create_events_table_sql(self) -> str: SQL statement to create adk_events table with indexes. Notes: - - 5-column schema: session_id, invocation_id, author, timestamp, event_json - - event_json stores the full ADK Event as a single JSON blob + - 5-column schema: session_id, invocation_id, author, timestamp, event_data + - event_data stores the full ADK Event as a single JSON blob - No decomposed columns -- eliminates column drift with upstream ADK - Foreign key constraint (DuckDB doesn't support CASCADE) - Index on (session_id, timestamp ASC) for ordered event retrieval @@ -144,7 +144,7 @@ async def _get_create_events_table_sql(self) -> str: invocation_id VARCHAR NOT NULL, author VARCHAR NOT NULL, timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - event_json JSON NOT NULL, + event_data JSON NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ); CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session ON {self._events_table}(session_id, timestamp ASC); @@ -195,7 +195,7 @@ def __get_create_events_table_sql_sync(self) -> str: invocation_id VARCHAR NOT NULL, author VARCHAR NOT NULL, timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - event_json JSON NOT NULL, + event_data JSON NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ); CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session ON {self._events_table}(session_id, timestamp ASC); @@ -410,11 +410,11 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis def _append_event(self, event_record: EventRecord) -> None: """Synchronous implementation of append_event.""" - event_json_str = to_json(event_record["event_json"]) + event_data_str = to_json(event_record["event_data"]) sql = f""" INSERT INTO {self._events_table} - (session_id, invocation_id, author, timestamp, event_json) + (session_id, invocation_id, author, timestamp, event_data) VALUES (?, ?, ?, ?, ?) """ @@ -426,7 +426,7 @@ def _append_event(self, event_record: EventRecord) -> None: event_record["invocation_id"], event_record["author"], event_record["timestamp"], - event_json_str, + event_data_str, ), ) conn.commit() @@ -436,7 +436,7 @@ async def append_event(self, event_record: EventRecord) -> None: Args: event_record: Event record with 5 keys (session_id, invocation_id, - author, timestamp, event_json). + author, timestamp, event_data). """ await async_(self._append_event)(event_record) @@ -446,11 +446,11 @@ def _append_event_and_update_state( """Synchronous implementation of append_event_and_update_state.""" now = datetime.now(timezone.utc) state_json = to_json(state) - event_json_str = to_json(event_record["event_json"]) + event_data_str = to_json(event_record["event_data"]) insert_sql = f""" INSERT INTO {self._events_table} - (session_id, invocation_id, author, timestamp, event_json) + (session_id, invocation_id, author, timestamp, event_data) VALUES (?, ?, ?, ?, ?) """ @@ -469,7 +469,7 @@ def _append_event_and_update_state( event_record["invocation_id"], event_record["author"], event_record["timestamp"], - event_json_str, + event_data_str, ), ) cursor = conn.execute(update_sql, (state_json, now, session_id)) @@ -522,7 +522,7 @@ def _get_events( limit_clause = f" LIMIT {limit}" if limit else "" sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_json + SELECT session_id, invocation_id, author, timestamp, event_data FROM {self._events_table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} @@ -539,7 +539,7 @@ def _get_events( invocation_id=row[1], author=row[2], timestamp=row[3], - event_json=from_json(row[4]) if isinstance(row[4], str) else row[4], + event_data=from_json(row[4]) if isinstance(row[4], str) else row[4], ) for row in rows ] diff --git a/sqlspec/adapters/mysqlconnector/adk/store.py b/sqlspec/adapters/mysqlconnector/adk/store.py index 035f8ef1c..74576eae9 100644 --- a/sqlspec/adapters/mysqlconnector/adk/store.py +++ b/sqlspec/adapters/mysqlconnector/adk/store.py @@ -73,7 +73,7 @@ def _mysql_events_ddl(events_table: str, session_table: str) -> str: invocation_id VARCHAR(256) NOT NULL, author VARCHAR(128) NOT NULL, timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - event_json JSON NOT NULL, + event_data JSON NOT NULL, FOREIGN KEY (session_id) REFERENCES {session_table}(id) ON DELETE CASCADE, INDEX idx_{events_table}_session (session_id, timestamp ASC) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci @@ -85,7 +85,7 @@ class MysqlConnectorAsyncADKStore(BaseAsyncADKStore["MysqlConnectorAsyncConfig"] Provides: - Session state management with JSON storage - - Full-event JSON storage (single ``event_json`` column) + - Full-event JSON storage (single ``event_data`` column) - Atomic event-append + state-update in one transaction - Microsecond-precision timestamps - Foreign key constraints with cascade delete @@ -256,14 +256,14 @@ async def append_event(self, event_record: EventRecord) -> None: Args: event_record: Event record with 5 keys (session_id, invocation_id, - author, timestamp, event_json). + author, timestamp, event_data). """ - event_json = event_record["event_json"] - event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json + event_data = event_record["event_data"] + event_data_str = to_json(event_data) if not isinstance(event_data, str) else event_data sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + session_id, invocation_id, author, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """ @@ -277,7 +277,7 @@ async def append_event(self, event_record: EventRecord) -> None: event_record["invocation_id"], event_record["author"], event_record["timestamp"], - event_json_str, + event_data_str, ), ) finally: @@ -298,13 +298,13 @@ async def append_event_and_update_state( session_id: Session identifier whose state should be updated. state: Post-append durable state snapshot. """ - event_json = event_record["event_json"] - event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json + event_data = event_record["event_data"] + event_data_str = to_json(event_data) if not isinstance(event_data, str) else event_data state_json = to_json(state) insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + session_id, invocation_id, author, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """ @@ -330,7 +330,7 @@ async def append_event_and_update_state( event_record["invocation_id"], event_record["author"], event_record["timestamp"], - event_json_str, + event_data_str, ), ) await cursor.execute(update_sql, (state_json, session_id)) @@ -378,7 +378,7 @@ async def get_events( limit_clause = f" LIMIT {limit}" if limit else "" sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_json + SELECT session_id, invocation_id, author, timestamp, event_data FROM {self._events_table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} @@ -399,7 +399,7 @@ async def get_events( invocation_id=cast("str", row[1]), author=cast("str", row[2]), timestamp=cast("datetime", row[3]), - event_json=from_json(row[4]) if isinstance(row[4], str) else cast("dict[str, Any]", row[4]), + event_data=from_json(row[4]) if isinstance(row[4], str) else cast("dict[str, Any]", row[4]), ) for row in rows ] @@ -414,7 +414,7 @@ class MysqlConnectorSyncADKStore(BaseAsyncADKStore["MysqlConnectorSyncConfig"]): Provides: - Session state management with JSON storage - - Full-event JSON storage (single ``event_json`` column) + - Full-event JSON storage (single ``event_data`` column) - Atomic event-create + state-update in one transaction - Microsecond-precision timestamps - Foreign key constraints with cascade delete @@ -624,13 +624,13 @@ def _append_event_and_update_state( session_id: Session identifier whose state should be updated. state: Post-append durable state snapshot. """ - event_json = event_record["event_json"] - event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json + event_data = event_record["event_data"] + event_data_str = to_json(event_data) if not isinstance(event_data, str) else event_data state_json = to_json(state) insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + session_id, invocation_id, author, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """ @@ -656,7 +656,7 @@ def _append_event_and_update_state( event_record["invocation_id"], event_record["author"], event_record["timestamp"], - event_json_str, + event_data_str, ), ) cursor.execute(update_sql, (state_json, session_id)) @@ -687,12 +687,12 @@ async def append_event_and_update_state( return await async_(self._append_event_and_update_state)(event_record, session_id, state) def _insert_event(self, event_record: EventRecord) -> None: - event_json = event_record["event_json"] - event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json + event_data = event_record["event_data"] + event_data_str = to_json(event_data) if not isinstance(event_data, str) else event_data sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + session_id, invocation_id, author, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """ @@ -706,7 +706,7 @@ def _insert_event(self, event_record: EventRecord) -> None: event_record["invocation_id"], event_record["author"], event_record["timestamp"], - event_json_str, + event_data_str, ), ) finally: @@ -736,7 +736,7 @@ def _get_events( where_clause = " AND ".join(where_clauses) limit_clause = " LIMIT %s" if limit else "" sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_json + SELECT session_id, invocation_id, author, timestamp, event_data FROM {self._events_table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} @@ -759,7 +759,7 @@ def _get_events( invocation_id=cast("str", row[1]), author=cast("str", row[2]), timestamp=cast("datetime", row[3]), - event_json=from_json(row[4]) if isinstance(row[4], str) else cast("dict[str, Any]", row[4]), + event_data=from_json(row[4]) if isinstance(row[4], str) else cast("dict[str, Any]", row[4]), ) for row in rows ] diff --git a/sqlspec/adapters/oracledb/adk/store.py b/sqlspec/adapters/oracledb/adk/store.py index 3d1dd7243..97b51b9a2 100644 --- a/sqlspec/adapters/oracledb/adk/store.py +++ b/sqlspec/adapters/oracledb/adk/store.py @@ -83,7 +83,7 @@ class OracleAsyncADKStore(BaseAsyncADKStore["OracleAsyncConfig"]): Implements session and event storage for Google Agent Development Kit using Oracle Database via the python-oracledb async driver. Provides: - Session state management with version-specific JSON storage - - Full-fidelity event storage via ``event_json`` column + - Full-fidelity event storage via ``event_data`` column - Atomic ``append_event_and_update_state`` for durable session mutations - TIMESTAMP WITH TIME ZONE for timezone-aware timestamps - Foreign key constraints with cascade delete @@ -94,7 +94,7 @@ class OracleAsyncADKStore(BaseAsyncADKStore["OracleAsyncConfig"]): Notes: - JSON storage type detected based on Oracle version (21c+, 12c+, legacy) - - event_json stored as JSON (21c+) or BLOB (older versions) + - event_data stored as JSON (21c+) or BLOB (older versions) - TIMESTAMP WITH TIME ZONE for timezone-aware timestamps - Named parameters using :param_name - State merging handled at application level @@ -229,15 +229,15 @@ async def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None": return None return await self._deserialize_state(data) - async def _serialize_event_json(self, event_json: Any) -> "str | bytes": - """Serialize event_json to the configured Oracle JSON storage format.""" + async def _serialize_event_data(self, event_data: Any) -> "str | bytes": + """Serialize event_data to the configured Oracle JSON storage format.""" storage_type = await self._detect_json_storage_type() if storage_type == JSONStorageType.JSON_NATIVE: - return to_json(event_json) - return to_json(event_json, as_bytes=True) + return to_json(event_data) + return to_json(event_data, as_bytes=True) - async def _read_event_json(self, data: Any) -> str: - """Read event_json from database, handling LOB types. + async def _read_event_data(self, data: Any) -> str: + """Read event_data from database, handling LOB types. Args: data: Data from database (may be LOB, str, or dict). @@ -325,7 +325,7 @@ def _get_create_events_table_sql_for_type(self, storage_type: JSONStorageType) - """Get Oracle CREATE TABLE SQL for events with specified storage type. The events table uses the new 5-column contract: session_id, invocation_id, - author, timestamp, and event_json. The event_json column stores the full + author, timestamp, and event_data. The event_data column stores the full ADK Event as JSON (21c+) or BLOB (older versions). Args: @@ -334,7 +334,7 @@ def _get_create_events_table_sql_for_type(self, storage_type: JSONStorageType) - Returns: SQL statement to create adk_events table. """ - event_json_col = _event_json_column_ddl(storage_type) + event_data_col = _event_data_column_ddl(storage_type) table_clauses = _oracle_table_feature_clauses( self._config, "events", @@ -350,7 +350,7 @@ def _get_create_events_table_sql_for_type(self, storage_type: JSONStorageType) - invocation_id VARCHAR2(256), author VARCHAR2(256), timestamp TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP NOT NULL, - {event_json_col}, + {event_data_col}, CONSTRAINT fk_{self._events_table}_session FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ){table_clauses}'; @@ -651,13 +651,13 @@ async def append_event(self, event_record: EventRecord) -> None: Args: event_record: Event record with 5 keys: session_id, invocation_id, - author, timestamp, event_json. + author, timestamp, event_data. """ sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + session_id, invocation_id, author, timestamp, event_data ) VALUES ( - :session_id, :invocation_id, :author, :timestamp, :event_json + :session_id, :invocation_id, :author, :timestamp, :event_data ) """ @@ -670,7 +670,7 @@ async def append_event(self, event_record: EventRecord) -> None: "invocation_id": event_record["invocation_id"], "author": event_record["author"], "timestamp": event_record["timestamp"], - "event_json": await self._serialize_event_json(event_record["event_json"]), + "event_data": await self._serialize_event_data(event_record["event_data"]), }, ) await conn.commit() @@ -688,16 +688,16 @@ async def append_event_and_update_state( Args: event_record: Event record with 5 keys: session_id, invocation_id, - author, timestamp, event_json. + author, timestamp, event_data. session_id: Session identifier whose state should be updated. state: Post-append durable state snapshot (``temp:`` keys already stripped by the service layer). """ insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + session_id, invocation_id, author, timestamp, event_data ) VALUES ( - :session_id, :invocation_id, :author, :timestamp, :event_json + :session_id, :invocation_id, :author, :timestamp, :event_data ) """ @@ -723,7 +723,7 @@ async def append_event_and_update_state( "invocation_id": event_record["invocation_id"], "author": event_record["author"], "timestamp": event_record["timestamp"], - "event_json": await self._serialize_event_json(event_record["event_json"]), + "event_data": await self._serialize_event_data(event_record["event_data"]), }, ) await cursor.execute(update_sql, {"state": state_data, "id": session_id}) @@ -772,7 +772,7 @@ async def get_events( limit_clause = f" FETCH FIRST {limit} ROWS ONLY" sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_json + SELECT session_id, invocation_id, author, timestamp, event_data FROM {self._events_table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} @@ -790,7 +790,7 @@ async def get_events( invocation_id=_oracle_text_value(row[1]), author=_oracle_text_value(row[2]), timestamp=row[3], - event_json=await self._deserialize_json_field(row[4]) or {}, + event_data=await self._deserialize_json_field(row[4]) or {}, ) for row in rows ] @@ -807,7 +807,7 @@ class OracleSyncADKStore(BaseAsyncADKStore["OracleSyncConfig"]): Implements session and event storage for Google Agent Development Kit using Oracle Database via the python-oracledb synchronous driver. Provides: - Session state management with version-specific JSON storage - - Full-fidelity event storage via ``event_json`` column + - Full-fidelity event storage via ``event_data`` column - Atomic ``create_event_and_update_state`` for durable session mutations - TIMESTAMP WITH TIME ZONE for timezone-aware timestamps - Foreign key constraints with cascade delete @@ -818,7 +818,7 @@ class OracleSyncADKStore(BaseAsyncADKStore["OracleSyncConfig"]): Notes: - JSON storage type detected based on Oracle version (21c+, 12c+, legacy) - - event_json stored as JSON (21c+) or BLOB (older versions) + - event_data stored as JSON (21c+) or BLOB (older versions) - TIMESTAMP WITH TIME ZONE for timezone-aware timestamps - Named parameters using :param_name - State merging handled at application level @@ -951,15 +951,15 @@ def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None": return None return self._deserialize_state(data) - def _serialize_event_json(self, event_json: Any) -> "str | bytes": - """Serialize event_json to the configured Oracle JSON storage format.""" + def _serialize_event_data(self, event_data: Any) -> "str | bytes": + """Serialize event_data to the configured Oracle JSON storage format.""" storage_type = self._detect_json_storage_type() if storage_type == JSONStorageType.JSON_NATIVE: - return to_json(event_json) - return to_json(event_json, as_bytes=True) + return to_json(event_data) + return to_json(event_data, as_bytes=True) - def _read_event_json(self, data: Any) -> str: - """Read event_json from database, handling LOB types. + def _read_event_data(self, data: Any) -> str: + """Read event_data from database, handling LOB types. Args: data: Data from database (may be LOB, str, or dict). @@ -1045,7 +1045,7 @@ def _get_create_events_table_sql_for_type(self, storage_type: JSONStorageType) - """Get Oracle CREATE TABLE SQL for events with specified storage type. The events table uses the new 5-column contract: session_id, invocation_id, - author, timestamp, and event_json. The event_json column stores the full + author, timestamp, and event_data. The event_data column stores the full ADK Event as JSON (21c+) or BLOB (older versions). Args: @@ -1054,7 +1054,7 @@ def _get_create_events_table_sql_for_type(self, storage_type: JSONStorageType) - Returns: SQL statement to create adk_events table. """ - event_json_col = _event_json_column_ddl(storage_type) + event_data_col = _event_data_column_ddl(storage_type) table_clauses = _oracle_table_feature_clauses( self._config, "events", @@ -1070,7 +1070,7 @@ def _get_create_events_table_sql_for_type(self, storage_type: JSONStorageType) - invocation_id VARCHAR2(256), author VARCHAR2(256), timestamp TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP NOT NULL, - {event_json_col}, + {event_data_col}, CONSTRAINT fk_{self._events_table}_session FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ){table_clauses}'; @@ -1408,16 +1408,16 @@ def _append_event_and_update_state( Args: event_record: Event record with 5 keys: session_id, invocation_id, - author, timestamp, event_json. + author, timestamp, event_data. session_id: Session identifier whose state should be updated. state: Post-append durable state snapshot (``temp:`` keys already stripped by the service layer). """ insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + session_id, invocation_id, author, timestamp, event_data ) VALUES ( - :session_id, :invocation_id, :author, :timestamp, :event_json + :session_id, :invocation_id, :author, :timestamp, :event_data ) """ @@ -1443,7 +1443,7 @@ def _append_event_and_update_state( "invocation_id": event_record["invocation_id"], "author": event_record["author"], "timestamp": event_record["timestamp"], - "event_json": self._serialize_event_json(event_record["event_json"]), + "event_data": self._serialize_event_data(event_record["event_data"]), }, ) cursor.execute(update_sql, {"state": state_data, "id": session_id}) @@ -1495,7 +1495,7 @@ def _get_events( where_clause = " AND ".join(where_clauses) limit_clause = f" FETCH FIRST {limit} ROWS ONLY" if limit else "" sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_json + SELECT session_id, invocation_id, author, timestamp, event_data FROM {self._events_table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} @@ -1513,7 +1513,7 @@ def _get_events( invocation_id=_oracle_text_value(row[1]), author=_oracle_text_value(row[2]), timestamp=row[3], - event_json=self._deserialize_json_field(row[4]) or {}, + event_data=self._deserialize_json_field(row[4]) or {}, ) for row in rows ] @@ -1533,9 +1533,9 @@ def _append_event(self, event_record: EventRecord) -> None: """Synchronous implementation of append_event.""" sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + session_id, invocation_id, author, timestamp, event_data ) VALUES ( - :session_id, :invocation_id, :author, :timestamp, :event_json + :session_id, :invocation_id, :author, :timestamp, :event_data ) """ @@ -1548,7 +1548,7 @@ def _append_event(self, event_record: EventRecord) -> None: "invocation_id": event_record["invocation_id"], "author": event_record["author"], "timestamp": event_record["timestamp"], - "event_json": self._serialize_event_json(event_record["event_json"]), + "event_data": self._serialize_event_data(event_record["event_data"]), }, ) conn.commit() @@ -2325,13 +2325,13 @@ def _extract_json_value(data: Any) -> "dict[str, Any]": return from_json(str(data)) # type: ignore[no-any-return] -def _event_json_column_ddl(storage_type: JSONStorageType) -> str: - """Return the DDL fragment for the event_json column.""" +def _event_data_column_ddl(storage_type: JSONStorageType) -> str: + """Return the DDL fragment for the event_data column.""" if storage_type == JSONStorageType.JSON_NATIVE: - return "event_json JSON NOT NULL" + return "event_data JSON NOT NULL" if storage_type == JSONStorageType.BLOB_JSON: - return "event_json BLOB CHECK (event_json IS JSON) NOT NULL" - return "event_json BLOB NOT NULL" + return "event_data BLOB CHECK (event_data IS JSON) NOT NULL" + return "event_data BLOB NOT NULL" def _get_oracle_adk_config(config: Any) -> dict[str, Any]: diff --git a/sqlspec/adapters/psqlpy/adk/store.py b/sqlspec/adapters/psqlpy/adk/store.py index 84155f338..a1b682cef 100644 --- a/sqlspec/adapters/psqlpy/adk/store.py +++ b/sqlspec/adapters/psqlpy/adk/store.py @@ -29,12 +29,12 @@ class PsqlpyADKStore(BaseAsyncADKStore["PsqlpyConfig"]): Implements session and event storage for Google Agent Development Kit using PostgreSQL via the high-performance Rust-based psqlpy driver. - Events are stored as a single JSONB blob (``event_json``) alongside + Events are stored as a single JSONB blob (``event_data``) alongside indexed scalar columns for efficient querying. Provides: - Session state management with JSONB storage - - Full-fidelity event storage via ``event_json`` JSONB column + - Full-fidelity event storage via ``event_data`` JSONB column - Atomic ``append_event_and_update_state`` for durable session mutations - Microsecond-precision timestamps with TIMESTAMPTZ - Foreign key constraints with cascade delete @@ -83,7 +83,7 @@ async def _get_create_events_table_sql(self) -> str: invocation_id VARCHAR(256) NOT NULL, author VARCHAR(256) NOT NULL, timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - event_json JSONB NOT NULL, + event_data JSONB NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ) WITH (fillfactor = 80); @@ -208,7 +208,7 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis async def append_event(self, event_record: EventRecord) -> None: sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + session_id, invocation_id, author, timestamp, event_data ) VALUES ($1, $2, $3, $4, $5) """ @@ -220,7 +220,7 @@ async def append_event(self, event_record: EventRecord) -> None: event_record["invocation_id"], event_record["author"], event_record["timestamp"], - event_record["event_json"], + event_record["event_data"], ], ) @@ -229,7 +229,7 @@ async def append_event_and_update_state( ) -> SessionRecord: insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + session_id, invocation_id, author, timestamp, event_data ) VALUES ($1, $2, $3, $4, $5) """ update_sql = f""" @@ -247,7 +247,7 @@ async def append_event_and_update_state( event_record["invocation_id"], event_record["author"], event_record["timestamp"], - event_record["event_json"], + event_record["event_data"], ], ) result = await conn.fetch(update_sql, [state, session_id]) @@ -283,7 +283,7 @@ async def get_events( params.append(limit) sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_json + SELECT session_id, invocation_id, author, timestamp, event_data FROM {self._events_table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} @@ -300,7 +300,7 @@ async def get_events( invocation_id=row["invocation_id"], author=row["author"], timestamp=row["timestamp"], - event_json=row["event_json"], + event_data=row["event_data"], ) for row in rows ] diff --git a/sqlspec/adapters/psycopg/adk/store.py b/sqlspec/adapters/psycopg/adk/store.py index 54de0715c..cf5135308 100644 --- a/sqlspec/adapters/psycopg/adk/store.py +++ b/sqlspec/adapters/psycopg/adk/store.py @@ -61,12 +61,12 @@ class PsycopgAsyncADKStore(BaseAsyncADKStore["PsycopgAsyncConfig"]): Implements session and event storage for Google Agent Development Kit using PostgreSQL via psycopg3 with native async/await support. - Events are stored as a single JSONB blob (``event_json``) alongside + Events are stored as a single JSONB blob (``event_data``) alongside indexed scalar columns for efficient querying. Provides: - Session state management with JSONB storage - - Full-fidelity event storage via ``event_json`` JSONB column + - Full-fidelity event storage via ``event_data`` JSONB column - Atomic ``append_event_and_update_state`` for durable session mutations - Microsecond-precision timestamps with TIMESTAMPTZ - Foreign key constraints with cascade delete @@ -115,7 +115,7 @@ async def _get_create_events_table_sql(self) -> str: invocation_id VARCHAR(256) NOT NULL, author VARCHAR(256) NOT NULL, timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - event_json JSONB NOT NULL, + event_data JSONB NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ) WITH (fillfactor = 80); @@ -237,12 +237,12 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis async def append_event(self, event_record: EventRecord) -> None: query = pg_sql.SQL(""" INSERT INTO {table} ( - session_id, invocation_id, author, timestamp, event_json + session_id, invocation_id, author, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """).format(table=pg_sql.Identifier(self._events_table)) - event_json_value = event_record["event_json"] - jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value + event_data_value = event_record["event_data"] + jsonb_value = Jsonb(event_data_value) if isinstance(event_data_value, dict) else event_data_value async with self._config.provide_connection() as conn, conn.cursor() as cur: await cur.execute( @@ -261,7 +261,7 @@ async def append_event_and_update_state( ) -> SessionRecord: insert_query = pg_sql.SQL(""" INSERT INTO {table} ( - session_id, invocation_id, author, timestamp, event_json + session_id, invocation_id, author, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """).format(table=pg_sql.Identifier(self._events_table)) @@ -272,8 +272,8 @@ async def append_event_and_update_state( RETURNING id, app_name, user_id, state, create_time, update_time """).format(table=pg_sql.Identifier(self._session_table)) - event_json_value = event_record["event_json"] - jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value + event_data_value = event_record["event_data"] + jsonb_value = Jsonb(event_data_value) if isinstance(event_data_value, dict) else event_data_value async with self._config.provide_connection() as conn, conn.cursor() as cur: await cur.execute( @@ -319,7 +319,7 @@ async def get_events( query = pg_sql.SQL( """ - SELECT session_id, invocation_id, author, timestamp, event_json + SELECT session_id, invocation_id, author, timestamp, event_data FROM {table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} @@ -341,7 +341,7 @@ async def get_events( invocation_id=row["invocation_id"], author=row["author"], timestamp=row["timestamp"], - event_json=row["event_json"], + event_data=row["event_data"], ) for row in rows ] @@ -354,12 +354,12 @@ class PsycopgSyncADKStore(BaseAsyncADKStore["PsycopgSyncConfig"]): Implements session and event storage for Google Agent Development Kit using PostgreSQL via psycopg3 with synchronous execution. - Events are stored as a single JSONB blob (``event_json``) alongside + Events are stored as a single JSONB blob (``event_data``) alongside indexed scalar columns for efficient querying. Provides: - Session state management with JSONB storage - - Full-fidelity event storage via ``event_json`` JSONB column + - Full-fidelity event storage via ``event_data`` JSONB column - Atomic ``create_event_and_update_state`` for durable session mutations - Microsecond-precision timestamps with TIMESTAMPTZ - Foreign key constraints with cascade delete @@ -408,7 +408,7 @@ async def _get_create_events_table_sql(self) -> str: invocation_id VARCHAR(256) NOT NULL, author VARCHAR(256) NOT NULL, timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - event_json JSONB NOT NULL, + event_data JSONB NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ) WITH (fillfactor = 80); @@ -560,12 +560,12 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis def _insert_event(self, event_record: EventRecord) -> None: insert_query = pg_sql.SQL(""" INSERT INTO {table} ( - session_id, invocation_id, author, timestamp, event_json + session_id, invocation_id, author, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """).format(table=pg_sql.Identifier(self._events_table)) - event_json_value = event_record["event_json"] - jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value + event_data_value = event_record["event_data"] + jsonb_value = Jsonb(event_data_value) if isinstance(event_data_value, dict) else event_data_value with self._config.provide_connection() as conn, conn.cursor() as cur: cur.execute( @@ -585,7 +585,7 @@ def _append_event_and_update_state( ) -> SessionRecord: insert_query = pg_sql.SQL(""" INSERT INTO {table} ( - session_id, invocation_id, author, timestamp, event_json + session_id, invocation_id, author, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """).format(table=pg_sql.Identifier(self._events_table)) @@ -596,8 +596,8 @@ def _append_event_and_update_state( RETURNING id, app_name, user_id, state, create_time, update_time """).format(table=pg_sql.Identifier(self._session_table)) - event_json_value = event_record["event_json"] - jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value + event_data_value = event_record["event_data"] + jsonb_value = Jsonb(event_data_value) if isinstance(event_data_value, dict) else event_data_value with self._config.provide_connection() as conn, conn.cursor() as cur: cur.execute( @@ -649,7 +649,7 @@ def _get_events( query = pg_sql.SQL( """ - SELECT session_id, invocation_id, author, timestamp, event_json + SELECT session_id, invocation_id, author, timestamp, event_data FROM {table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} @@ -671,7 +671,7 @@ def _get_events( invocation_id=row["invocation_id"], author=row["author"], timestamp=row["timestamp"], - event_json=row["event_json"], + event_data=row["event_data"], ) for row in rows ] diff --git a/sqlspec/adapters/pymysql/adk/store.py b/sqlspec/adapters/pymysql/adk/store.py index c121ed47c..60e7b095f 100644 --- a/sqlspec/adapters/pymysql/adk/store.py +++ b/sqlspec/adapters/pymysql/adk/store.py @@ -40,7 +40,7 @@ class PyMysqlADKStore(BaseAsyncADKStore["PyMysqlConfig"]): Implements session and event storage for Google Agent Development Kit using MySQL/MariaDB via the PyMySQL sync driver. Provides: - Session state management with JSON storage - - Full-event JSON storage (single ``event_json`` column) + - Full-event JSON storage (single ``event_data`` column) - Atomic event-create + state-update in one transaction - Microsecond-precision timestamps - Foreign key constraints with cascade delete @@ -95,7 +95,7 @@ async def _get_create_events_table_sql(self) -> str: invocation_id VARCHAR(256) NOT NULL, author VARCHAR(128) NOT NULL, timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - event_json JSON NOT NULL, + event_data JSON NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE, INDEX idx_{self._events_table}_session (session_id, timestamp ASC) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci @@ -287,13 +287,13 @@ def _append_event_and_update_state( session_id: Session identifier whose state should be updated. state: Post-append durable state snapshot. """ - event_json = event_record["event_json"] - event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json + event_data = event_record["event_data"] + event_data_str = to_json(event_data) if not isinstance(event_data, str) else event_data state_json = to_json(state) insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + session_id, invocation_id, author, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """ @@ -319,7 +319,7 @@ def _append_event_and_update_state( event_record["invocation_id"], event_record["author"], event_record["timestamp"], - event_json_str, + event_data_str, ), ) cursor.execute(update_sql, (state_json, session_id)) @@ -350,12 +350,12 @@ async def append_event_and_update_state( return await async_(self._append_event_and_update_state)(event_record, session_id, state) def _insert_event(self, event_record: EventRecord) -> None: - event_json = event_record["event_json"] - event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json + event_data = event_record["event_data"] + event_data_str = to_json(event_data) if not isinstance(event_data, str) else event_data sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + session_id, invocation_id, author, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """ @@ -369,7 +369,7 @@ def _insert_event(self, event_record: EventRecord) -> None: event_record["invocation_id"], event_record["author"], event_record["timestamp"], - event_json_str, + event_data_str, ), ) finally: @@ -399,7 +399,7 @@ def _get_events( where_clause = " AND ".join(where_clauses) limit_clause = " LIMIT %s" if limit else "" sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_json + SELECT session_id, invocation_id, author, timestamp, event_data FROM {self._events_table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} @@ -422,7 +422,7 @@ def _get_events( invocation_id=row[1], author=row[2], timestamp=row[3], - event_json=from_json(row[4]) if isinstance(row[4], str) else row[4], + event_data=from_json(row[4]) if isinstance(row[4], str) else row[4], ) for row in rows ] diff --git a/sqlspec/adapters/spanner/adk/store.py b/sqlspec/adapters/spanner/adk/store.py index a389ef26d..4234ec9ba 100644 --- a/sqlspec/adapters/spanner/adk/store.py +++ b/sqlspec/adapters/spanner/adk/store.py @@ -73,7 +73,7 @@ def _event_param_types(self) -> "dict[str, Any]": "invocation_id": SPANNER_PARAM_TYPES.STRING, "author": SPANNER_PARAM_TYPES.STRING, "timestamp": SPANNER_PARAM_TYPES.TIMESTAMP, - "event_json": json_type, + "event_data": json_type, } def _decode_state(self, raw: Any) -> Any: @@ -239,11 +239,11 @@ def _append_event_and_update_state( "invocation_id": event_record["invocation_id"], "author": event_record["author"], "timestamp": event_record["timestamp"], - "event_json": to_json(event_record["event_json"]), + "event_data": to_json(event_record["event_data"]), } insert_sql = f""" - INSERT INTO {self._events_table} (session_id, invocation_id, author, timestamp, event_json) - VALUES (@session_id, @invocation_id, @author, @timestamp, @event_json) + INSERT INTO {self._events_table} (session_id, invocation_id, author, timestamp, event_data) + VALUES (@session_id, @invocation_id, @author, @timestamp, @event_data) """ json_type = _json_param_type() @@ -279,11 +279,11 @@ def _insert_event(self, event_record: "EventRecord") -> None: "invocation_id": event_record["invocation_id"], "author": event_record["author"], "timestamp": event_record["timestamp"], - "event_json": to_json(event_record["event_json"]), + "event_data": to_json(event_record["event_data"]), } insert_sql = f""" - INSERT INTO {self._events_table} (session_id, invocation_id, author, timestamp, event_json) - VALUES (@session_id, @invocation_id, @author, @timestamp, @event_json) + INSERT INTO {self._events_table} (session_id, invocation_id, author, timestamp, event_data) + VALUES (@session_id, @invocation_id, @author, @timestamp, @event_data) """ self._run_write([(insert_sql, event_params, self._event_param_types())]) @@ -291,7 +291,7 @@ def _get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None ) -> "list[EventRecord]": sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_json + SELECT session_id, invocation_id, author, timestamp, event_data FROM {self._events_table} WHERE session_id = @session_id """ @@ -315,7 +315,7 @@ def _get_events( "invocation_id": row[1] or "", "author": row[2] or "", "timestamp": row[3], - "event_json": row[4], + "event_data": row[4], } for row in rows ] @@ -389,7 +389,7 @@ async def _get_create_events_table_sql(self) -> str: invocation_id STRING(256) NOT NULL, author STRING(128) NOT NULL, timestamp TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp=true), - event_json JSON NOT NULL{shard_column} + event_data JSON NOT NULL{shard_column} ) {pk}{options}{self._events_row_deletion_policy} """ diff --git a/sqlspec/adapters/sqlite/adk/store.py b/sqlspec/adapters/sqlite/adk/store.py index 137beba74..275cd4858 100644 --- a/sqlspec/adapters/sqlite/adk/store.py +++ b/sqlspec/adapters/sqlite/adk/store.py @@ -425,7 +425,7 @@ async def delete_session(self, session_id: str) -> None: def _append_event(self, event_record: EventRecord) -> None: """Synchronous implementation of append_event.""" timestamp_julian = _datetime_to_julian(event_record["timestamp"]) - event_data_json = to_json(event_record["event_json"]) + event_data_json = to_json(event_record["event_data"]) sql = f""" INSERT INTO {self._events_table} ( @@ -457,11 +457,11 @@ async def append_event(self, event_record: EventRecord) -> None: Args: event_record: Event record with 5 keys: session_id, invocation_id, - author, timestamp, event_json. + author, timestamp, event_data. Notes: Uses Julian Day for timestamp. - event_json dict is serialized to TEXT as event_data column. + event_data dict is serialized to TEXT as event_data column. """ await async_(self._append_event)(event_record) @@ -472,7 +472,7 @@ def _append_event_and_update_state( import uuid timestamp_julian = _datetime_to_julian(event_record["timestamp"]) - event_data_json = to_json(event_record["event_json"]) + event_data_json = to_json(event_record["event_data"]) now_julian = _datetime_to_julian(datetime.now(timezone.utc)) state_json = to_json(state) event_id = str(uuid.uuid4()) @@ -569,7 +569,7 @@ def _get_events( invocation_id=row[2], author=row[3], timestamp=_julian_to_datetime(row[4]), - event_json=from_json(row[5]) if row[5] else {}, + event_data=from_json(row[5]) if row[5] else {}, ) for row in rows ] @@ -593,7 +593,7 @@ async def get_events( Notes: Uses index on (session_id, timestamp ASC). - Parses event_data TEXT back to dict for event_json field. + Parses event_data TEXT back to dict for event_data field. """ return await async_(self._get_events)(session_id, after_timestamp, limit) diff --git a/sqlspec/extensions/adk/_types.py b/sqlspec/extensions/adk/_types.py index 3f11b62f0..2864c437b 100644 --- a/sqlspec/extensions/adk/_types.py +++ b/sqlspec/extensions/adk/_types.py @@ -27,15 +27,15 @@ class SessionRecord(TypedDict): class EventRecord(TypedDict): """Database record for an event. - Stores the full ADK Event as a single JSON blob (``event_json``) alongside + Stores the full ADK Event as a single JSON blob (``event_data``) alongside a small number of indexed scalar columns used for query filtering. This design eliminates column drift with upstream ADK: new Event fields are - automatically captured in ``event_json`` without schema changes. + automatically captured in ``event_data`` without schema changes. """ session_id: str invocation_id: str author: str timestamp: datetime - event_json: "dict[str, Any]" + event_data: "dict[str, Any]" diff --git a/sqlspec/extensions/adk/converters.py b/sqlspec/extensions/adk/converters.py index d25161904..6ebd82e06 100644 --- a/sqlspec/extensions/adk/converters.py +++ b/sqlspec/extensions/adk/converters.py @@ -1,7 +1,7 @@ """Conversion functions between ADK models and database records. Implements full-event JSON storage: the entire Event is serialized via -``Event.model_dump_json(exclude_none=True)`` into a single ``event_json`` +``Event.model_dump_json(exclude_none=True)`` into a single ``event_data`` column, with a small set of indexed scalar columns extracted alongside for query performance. Reconstruction uses ``Event.model_validate_json()``. @@ -106,7 +106,7 @@ def record_to_session(record: SessionRecord, events: "list[EventRecord]") -> "Se def event_to_record(event: "Event", session_id: str) -> EventRecord: """Convert ADK Event to database record using full-event JSON storage. - The entire Event is serialized into ``event_json`` via Pydantic's + The entire Event is serialized into ``event_data`` via Pydantic's ``model_dump_json(exclude_none=True)``. A small number of indexed scalar columns are extracted alongside for query performance. @@ -122,7 +122,7 @@ def event_to_record(event: "Event", session_id: str) -> EventRecord: invocation_id=event.invocation_id, author=event.author, timestamp=datetime.fromtimestamp(event.timestamp, tz=timezone.utc), - event_json=event.model_dump(exclude_none=True, mode="json"), + event_data=event.model_dump(exclude_none=True, mode="json"), ) @@ -130,7 +130,7 @@ def record_to_event(record: "EventRecord") -> "Event": """Convert database record to ADK Event. Reconstruction is lossless: the full Event is restored from - ``event_json`` via ``Event.model_validate_json()``. + ``event_data`` via ``Event.model_validate_json()``. Args: record: Event database record. @@ -138,7 +138,7 @@ def record_to_event(record: "EventRecord") -> "Event": Returns: ADK Event object. """ - return Event.model_validate(record["event_json"]) + return Event.model_validate(record["event_data"]) # --------------------------------------------------------------------------- diff --git a/tests/integration/adapters/adbc/extensions/adk/test_dialect_integration.py b/tests/integration/adapters/adbc/extensions/adk/test_dialect_integration.py index e20536d83..f7cf9c751 100644 --- a/tests/integration/adapters/adbc/extensions/adk/test_dialect_integration.py +++ b/tests/integration/adapters/adbc/extensions/adk/test_dialect_integration.py @@ -86,14 +86,14 @@ async def test_sqlite_dialect_event_operations(sqlite_store: Any) -> None: "invocation_id": "", "author": "", "timestamp": datetime.now(timezone.utc), - "event_json": {"id": "event-1", "content": content, "app_name": app_name, "user_id": user_id}, + "event_data": {"id": "event-1", "content": content, "app_name": app_name, "user_id": user_id}, } await sqlite_store.append_event(event_record) events = await sqlite_store.get_events(session_id) assert len(events) == 1 retrieved_data = ( - json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + json.loads(events[0]["event_data"]) if isinstance(events[0]["event_data"], str) else events[0]["event_data"] ) assert retrieved_data["content"] == content diff --git a/tests/integration/adapters/adbc/extensions/adk/test_dialect_support.py b/tests/integration/adapters/adbc/extensions/adk/test_dialect_support.py index 703d40437..5785879c0 100644 --- a/tests/integration/adapters/adbc/extensions/adk/test_dialect_support.py +++ b/tests/integration/adapters/adbc/extensions/adk/test_dialect_support.py @@ -90,12 +90,12 @@ def test_generic_sessions_ddl_contains_text() -> None: def test_postgresql_events_ddl_uses_jsonb() -> None: - """Test PostgreSQL events DDL uses JSONB for event_json.""" + """Test PostgreSQL events DDL uses JSONB for event_data.""" config = AdbcConfig(connection_config={"driver_name": "postgresql", "uri": ":memory:"}) store = AdbcADKStore(config) ddl = store._get_events_ddl_postgresql() # pyright: ignore[reportPrivateUsage] assert "JSONB" in ddl - assert "event_json" in ddl + assert "event_data" in ddl assert "session_id" in ddl assert "invocation_id" in ddl assert "author" in ddl @@ -103,32 +103,32 @@ def test_postgresql_events_ddl_uses_jsonb() -> None: def test_sqlite_events_ddl_uses_text() -> None: - """Test SQLite events DDL uses TEXT for event_json.""" + """Test SQLite events DDL uses TEXT for event_data.""" config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": ":memory:"}) store = AdbcADKStore(config) ddl = store._get_events_ddl_sqlite() # pyright: ignore[reportPrivateUsage] assert "TEXT" in ddl - assert "event_json" in ddl + assert "event_data" in ddl assert "session_id" in ddl assert "REAL" in ddl # SQLite uses REAL for timestamps def test_duckdb_events_ddl_uses_json() -> None: - """Test DuckDB events DDL uses JSON type for event_json.""" + """Test DuckDB events DDL uses JSON type for event_data.""" config = AdbcConfig(connection_config={"driver_name": "duckdb", "uri": ":memory:"}) store = AdbcADKStore(config) ddl = store._get_events_ddl_duckdb() # pyright: ignore[reportPrivateUsage] assert "JSON" in ddl - assert "event_json" in ddl + assert "event_data" in ddl def test_snowflake_events_ddl_uses_variant() -> None: - """Test Snowflake events DDL uses VARIANT for event_json.""" + """Test Snowflake events DDL uses VARIANT for event_data.""" config = AdbcConfig(connection_config={"driver_name": "snowflake", "uri": "snowflake://test"}) store = AdbcADKStore(config) ddl = store._get_events_ddl_snowflake() # pyright: ignore[reportPrivateUsage] assert "VARIANT" in ddl - assert "event_json" in ddl + assert "event_data" in ddl async def test_ddl_dispatch_uses_correct_dialect() -> None: @@ -141,7 +141,7 @@ async def test_ddl_dispatch_uses_correct_dialect() -> None: events_ddl = await store._get_create_events_table_sql() # pyright: ignore[reportPrivateUsage] assert "JSONB" in events_ddl - assert "event_json" in events_ddl + assert "event_data" in events_ddl def test_owner_id_column_included_in_sessions_ddl() -> None: diff --git a/tests/integration/adapters/adbc/extensions/adk/test_edge_cases.py b/tests/integration/adapters/adbc/extensions/adk/test_edge_cases.py index 0e11dd0bb..c8a0be2aa 100644 --- a/tests/integration/adapters/adbc/extensions/adk/test_edge_cases.py +++ b/tests/integration/adapters/adbc/extensions/adk/test_edge_cases.py @@ -117,7 +117,7 @@ async def test_unicode_in_fields(adbc_store: Any) -> None: "invocation_id": "", "author": "\u30a2\u30b7\u30b9\u30bf\u30f3\u30c8", "timestamp": datetime.now(timezone.utc), - "event_json": { + "event_data": { "id": "unicode-event", "content": {"text": "\u3053\u3093\u306b\u3061\u306f"}, "app_name": app_name, @@ -130,7 +130,7 @@ async def test_unicode_in_fields(adbc_store: Any) -> None: assert len(events) == 1 assert events[0]["author"] == "\u30a2\u30b7\u30b9\u30bf\u30f3\u30c8" event_data = ( - json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + json.loads(events[0]["event_data"]) if isinstance(events[0]["event_data"], str) else events[0]["event_data"] ) assert event_data["content"]["text"] == "\u3053\u3093\u306b\u3061\u306f" @@ -203,14 +203,14 @@ async def test_event_with_none_values(adbc_store: Any) -> None: "invocation_id": "", "author": "", "timestamp": datetime.now(timezone.utc), - "event_json": {"id": "none-event", "app_name": "app", "user_id": "user"}, + "event_data": {"id": "none-event", "app_name": "app", "user_id": "user"}, } await adbc_store.append_event(event_record) events = await adbc_store.get_events(session_id) assert len(events) == 1 assert events[0]["session_id"] == session_id - assert "event_json" in events[0] + assert "event_data" in events[0] async def test_list_sessions_with_same_user_different_apps(adbc_store: Any) -> None: diff --git a/tests/integration/adapters/adbc/extensions/adk/test_event_operations.py b/tests/integration/adapters/adbc/extensions/adk/test_event_operations.py index 8e54bf766..d21b612f4 100644 --- a/tests/integration/adapters/adbc/extensions/adk/test_event_operations.py +++ b/tests/integration/adapters/adbc/extensions/adk/test_event_operations.py @@ -43,7 +43,7 @@ async def test_create_event(adbc_store: Any, session_fixture: Any) -> None: "invocation_id": "", "author": "user", "timestamp": datetime.now(timezone.utc), - "event_json": { + "event_data": { "id": "event-1", "content": {"message": "Hello"}, "app_name": session_fixture["app_name"], @@ -57,11 +57,11 @@ async def test_create_event(adbc_store: Any, session_fixture: Any) -> None: assert events[0]["session_id"] == session_fixture["session_id"] assert events[0]["author"] == "user" assert events[0]["timestamp"] is not None - assert "event_json" in events[0] + assert "event_data" in events[0] - # Content is stored inside event_json + # Content is stored inside event_data event_data = ( - json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + json.loads(events[0]["event_data"]) if isinstance(events[0]["event_data"], str) else events[0]["event_data"] ) assert event_data["content"] == {"message": "Hello"} @@ -73,7 +73,7 @@ async def test_list_events(adbc_store: Any, session_fixture: Any) -> None: "invocation_id": "", "author": "user", "timestamp": datetime.now(timezone.utc), - "event_json": { + "event_data": { "id": "event-1", "content": {"seq": 1}, "app_name": session_fixture["app_name"], @@ -85,7 +85,7 @@ async def test_list_events(adbc_store: Any, session_fixture: Any) -> None: "invocation_id": "", "author": "assistant", "timestamp": datetime.now(timezone.utc), - "event_json": { + "event_data": { "id": "event-2", "content": {"seq": 2}, "app_name": session_fixture["app_name"], @@ -109,13 +109,13 @@ async def test_list_events_empty(adbc_store: Any, session_fixture: Any) -> None: async def test_event_with_all_fields(adbc_store: Any, session_fixture: Any) -> None: - """Test creating event with all optional fields stored in event_json.""" + """Test creating event with all optional fields stored in event_data.""" event_record: EventRecord = { "session_id": session_fixture["session_id"], "invocation_id": "invocation-123", "author": "assistant", "timestamp": datetime.now(timezone.utc), - "event_json": { + "event_data": { "id": "full-event", "content": {"text": "Response"}, "app_name": session_fixture["app_name"], @@ -139,9 +139,9 @@ async def test_event_with_all_fields(adbc_store: Any, session_fixture: Any) -> N assert events[0]["invocation_id"] == "invocation-123" assert events[0]["author"] == "assistant" - # Everything else is in event_json + # Everything else is in event_data event_data = ( - json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + json.loads(events[0]["event_data"]) if isinstance(events[0]["event_data"], str) else events[0]["event_data"] ) assert event_data["content"] == {"text": "Response"} assert event_data["branch"] == "main" @@ -161,7 +161,7 @@ async def test_event_with_minimal_fields(adbc_store: Any, session_fixture: Any) "invocation_id": "", "author": "", "timestamp": datetime.now(timezone.utc), - "event_json": { + "event_data": { "id": "minimal-event", "app_name": session_fixture["app_name"], "user_id": session_fixture["user_id"], @@ -172,11 +172,11 @@ async def test_event_with_minimal_fields(adbc_store: Any, session_fixture: Any) events = await adbc_store.get_events(session_fixture["session_id"]) assert len(events) == 1 assert events[0]["session_id"] == session_fixture["session_id"] - assert "event_json" in events[0] + assert "event_data" in events[0] -async def test_event_json_fields(adbc_store: Any, session_fixture: Any) -> None: - """Test event JSON field serialization and deserialization via event_json.""" +async def test_event_data_fields(adbc_store: Any, session_fixture: Any) -> None: + """Test event JSON field serialization and deserialization via event_data.""" complex_content = {"nested": {"data": "value"}, "list": [1, 2, 3], "null": None} complex_grounding = {"sources": [{"title": "Doc", "url": "http://example.com"}]} complex_custom = {"metadata": {"version": 1, "tags": ["tag1", "tag2"]}} @@ -186,7 +186,7 @@ async def test_event_json_fields(adbc_store: Any, session_fixture: Any) -> None: "invocation_id": "", "author": "", "timestamp": datetime.now(timezone.utc), - "event_json": { + "event_data": { "id": "json-event", "content": complex_content, "grounding_metadata": complex_grounding, @@ -200,7 +200,7 @@ async def test_event_json_fields(adbc_store: Any, session_fixture: Any) -> None: events = await adbc_store.get_events(session_fixture["session_id"]) assert len(events) == 1 event_data = ( - json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + json.loads(events[0]["event_data"]) if isinstance(events[0]["event_data"], str) else events[0]["event_data"] ) assert event_data["content"] == complex_content assert event_data["grounding_metadata"] == complex_grounding @@ -214,7 +214,7 @@ async def test_event_ordering(adbc_store: Any, session_fixture: Any) -> None: "invocation_id": "", "author": "", "timestamp": datetime.now(timezone.utc), - "event_json": {"id": "event-1", "app_name": session_fixture["app_name"], "user_id": session_fixture["user_id"]}, + "event_data": {"id": "event-1", "app_name": session_fixture["app_name"], "user_id": session_fixture["user_id"]}, } await adbc_store.append_event(ev1) @@ -225,7 +225,7 @@ async def test_event_ordering(adbc_store: Any, session_fixture: Any) -> None: "invocation_id": "", "author": "", "timestamp": datetime.now(timezone.utc), - "event_json": {"id": "event-2", "app_name": session_fixture["app_name"], "user_id": session_fixture["user_id"]}, + "event_data": {"id": "event-2", "app_name": session_fixture["app_name"], "user_id": session_fixture["user_id"]}, } await adbc_store.append_event(ev2) @@ -236,7 +236,7 @@ async def test_event_ordering(adbc_store: Any, session_fixture: Any) -> None: "invocation_id": "", "author": "", "timestamp": datetime.now(timezone.utc), - "event_json": {"id": "event-3", "app_name": session_fixture["app_name"], "user_id": session_fixture["user_id"]}, + "event_data": {"id": "event-3", "app_name": session_fixture["app_name"], "user_id": session_fixture["user_id"]}, } await adbc_store.append_event(ev3) @@ -258,14 +258,14 @@ async def test_delete_session_cascades_events(adbc_store: Any, session_fixture: "invocation_id": "", "author": "", "timestamp": datetime.now(timezone.utc), - "event_json": {"id": "event-1", "app_name": session_fixture["app_name"], "user_id": session_fixture["user_id"]}, + "event_data": {"id": "event-1", "app_name": session_fixture["app_name"], "user_id": session_fixture["user_id"]}, } ev2: EventRecord = { "session_id": session_fixture["session_id"], "invocation_id": "", "author": "", "timestamp": datetime.now(timezone.utc), - "event_json": {"id": "event-2", "app_name": session_fixture["app_name"], "user_id": session_fixture["user_id"]}, + "event_data": {"id": "event-2", "app_name": session_fixture["app_name"], "user_id": session_fixture["user_id"]}, } await adbc_store.append_event(ev1) await adbc_store.append_event(ev2) @@ -286,7 +286,7 @@ async def test_event_with_empty_actions(adbc_store: Any, session_fixture: Any) - "invocation_id": "", "author": "", "timestamp": datetime.now(timezone.utc), - "event_json": { + "event_data": { "id": "empty-actions", "app_name": session_fixture["app_name"], "user_id": session_fixture["user_id"], @@ -296,11 +296,11 @@ async def test_event_with_empty_actions(adbc_store: Any, session_fixture: Any) - events = await adbc_store.get_events(session_fixture["session_id"]) assert len(events) == 1 - assert "event_json" in events[0] + assert "event_data" in events[0] async def test_event_with_large_content(adbc_store: Any, session_fixture: Any) -> None: - """Test creating event with large content in event_json.""" + """Test creating event with large content in event_data.""" large_content = {"data": "x" * 10000} event_record: EventRecord = { @@ -308,7 +308,7 @@ async def test_event_with_large_content(adbc_store: Any, session_fixture: Any) - "invocation_id": "", "author": "", "timestamp": datetime.now(timezone.utc), - "event_json": { + "event_data": { "id": "large-content", "content": large_content, "app_name": session_fixture["app_name"], @@ -320,7 +320,7 @@ async def test_event_with_large_content(adbc_store: Any, session_fixture: Any) - events = await adbc_store.get_events(session_fixture["session_id"]) assert len(events) == 1 event_data = ( - json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + json.loads(events[0]["event_data"]) if isinstance(events[0]["event_data"], str) else events[0]["event_data"] ) assert event_data["content"] == large_content @@ -332,7 +332,7 @@ async def test_append_event_preserves_existing_session_state(adbc_store: Any, se "invocation_id": "append-only", "author": "user", "timestamp": datetime.now(timezone.utc), - "event_json": { + "event_data": { "id": "append-only-event", "app_name": session_fixture["app_name"], "user_id": session_fixture["user_id"], @@ -355,7 +355,7 @@ async def test_get_events_applies_after_timestamp_and_limit(adbc_store: Any, ses "invocation_id": "", "author": "user", "timestamp": base_time, - "event_json": { + "event_data": { "id": "event-1", "app_name": session_fixture["app_name"], "user_id": session_fixture["user_id"], @@ -366,7 +366,7 @@ async def test_get_events_applies_after_timestamp_and_limit(adbc_store: Any, ses "invocation_id": "", "author": "assistant", "timestamp": base_time + timedelta(seconds=1), - "event_json": { + "event_data": { "id": "event-2", "app_name": session_fixture["app_name"], "user_id": session_fixture["user_id"], @@ -377,7 +377,7 @@ async def test_get_events_applies_after_timestamp_and_limit(adbc_store: Any, ses "invocation_id": "", "author": "assistant", "timestamp": base_time + timedelta(seconds=2), - "event_json": { + "event_data": { "id": "event-3", "app_name": session_fixture["app_name"], "user_id": session_fixture["user_id"], @@ -391,6 +391,6 @@ async def test_get_events_applies_after_timestamp_and_limit(adbc_store: Any, ses filtered_events = await adbc_store.get_events(session_fixture["session_id"], after_timestamp=base_time, limit=1) assert len(filtered_events) == 1 - filtered_event = filtered_events[0]["event_json"] + filtered_event = filtered_events[0]["event_data"] filtered_data = json.loads(filtered_event) if isinstance(filtered_event, str) else filtered_event assert filtered_data["id"] == "event-2" diff --git a/tests/integration/adapters/aiomysql/extensions/adk/test_store.py b/tests/integration/adapters/aiomysql/extensions/adk/test_store.py index 356bc6a73..b13389bd3 100644 --- a/tests/integration/adapters/aiomysql/extensions/adk/test_store.py +++ b/tests/integration/adapters/aiomysql/extensions/adk/test_store.py @@ -59,7 +59,7 @@ async def test_storage_types_verification(aiomysql_adk_store: AiomysqlADKStore) assert "invocation_id" in event_col_names assert "author" in event_col_names assert "timestamp" in event_col_names - assert "event_json" in event_col_names + assert "event_data" in event_col_names timestamp_col = next(col for col in event_columns if col[0] == "timestamp") assert "timestamp(6)" in timestamp_col[2].lower(), "timestamp must be TIMESTAMP(6) for microseconds" @@ -148,7 +148,7 @@ async def test_delete_session_cascade(aiomysql_adk_store: AiomysqlADKStore) -> N "invocation_id": "inv-001", "author": "user", "timestamp": datetime.now(timezone.utc), - "event_json": {"content": {"text": "Hello"}, "app_name": app_name, "user_id": user_id}, + "event_data": {"content": {"text": "Hello"}, "app_name": app_name, "user_id": user_id}, } await aiomysql_adk_store.append_event(event_record) @@ -177,7 +177,7 @@ async def test_append_and_get_events(aiomysql_adk_store: AiomysqlADKStore) -> No "invocation_id": "inv-001", "author": "user", "timestamp": datetime.now(timezone.utc), - "event_json": {"content": {"text": "Hello", "role": "user"}, "app_name": app_name}, + "event_data": {"content": {"text": "Hello", "role": "user"}, "app_name": app_name}, } event2: EventRecord = { @@ -185,7 +185,7 @@ async def test_append_and_get_events(aiomysql_adk_store: AiomysqlADKStore) -> No "invocation_id": "inv-002", "author": "assistant", "timestamp": datetime.now(timezone.utc), - "event_json": {"content": {"text": "Hi there", "role": "assistant"}, "app_name": app_name}, + "event_data": {"content": {"text": "Hi there", "role": "assistant"}, "app_name": app_name}, } await aiomysql_adk_store.append_event(event1) @@ -197,10 +197,10 @@ async def test_append_and_get_events(aiomysql_adk_store: AiomysqlADKStore) -> No assert events[0]["author"] == "user" assert events[1]["author"] == "assistant" event0_data = ( - json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + json.loads(events[0]["event_data"]) if isinstance(events[0]["event_data"], str) else events[0]["event_data"] ) event1_data = ( - json.loads(events[1]["event_json"]) if isinstance(events[1]["event_json"], str) else events[1]["event_json"] + json.loads(events[1]["event_data"]) if isinstance(events[1]["event_data"], str) else events[1]["event_data"] ) assert event0_data["content"]["text"] == "Hello" assert event1_data["content"]["text"] == "Hi there" @@ -223,7 +223,7 @@ async def test_timestamp_precision(aiomysql_adk_store: AiomysqlADKStore) -> None "invocation_id": "inv-micro", "author": "system", "timestamp": event_time, - "event_json": {"app_name": app_name}, + "event_data": {"app_name": app_name}, } await aiomysql_adk_store.append_event(event) diff --git a/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py b/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py index 393f6139d..35d169f96 100644 --- a/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py +++ b/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py @@ -66,7 +66,7 @@ async def test_aiosqlite_append_event_and_update_state_is_atomic_contract(tmp_pa "invocation_id": "inv-1", "author": "user", "timestamp": datetime(2026, 5, 10, 12, 0, tzinfo=timezone.utc), - "event_json": {"id": "event-1", "content": {"parts": [{"text": "hello"}]}}, + "event_data": {"id": "event-1", "content": {"parts": [{"text": "hello"}]}}, } await store.append_event_and_update_state(event, session_id, {"turn": 1}) @@ -77,7 +77,7 @@ async def test_aiosqlite_append_event_and_update_state_is_atomic_contract(tmp_pa assert session["state"] == {"turn": 1} assert len(events) == 1 assert events[0]["invocation_id"] == "inv-1" - assert events[0]["event_json"] == {"id": "event-1", "content": {"parts": [{"text": "hello"}]}} + assert events[0]["event_data"] == {"id": "event-1", "content": {"parts": [{"text": "hello"}]}} finally: await config.close_pool() @@ -110,7 +110,7 @@ async def test_aiosqlite_get_events_filters_by_timestamp_and_limit(tmp_path: Pat "invocation_id": f"inv-{index}", "author": "user", "timestamp": base + timedelta(seconds=index), - "event_json": {"id": f"event-{index}"}, + "event_data": {"id": f"event-{index}"}, } await store.append_event(event) @@ -118,6 +118,6 @@ async def test_aiosqlite_get_events_filters_by_timestamp_and_limit(tmp_path: Pat assert len(events) == 1 assert events[0]["invocation_id"] == "inv-1" - assert events[0]["event_json"] == {"id": "event-1"} + assert events[0]["event_data"] == {"id": "event-1"} finally: await config.close_pool() diff --git a/tests/integration/adapters/asyncmy/extensions/adk/test_store.py b/tests/integration/adapters/asyncmy/extensions/adk/test_store.py index 47722093e..1a8f5a2b3 100644 --- a/tests/integration/adapters/asyncmy/extensions/adk/test_store.py +++ b/tests/integration/adapters/asyncmy/extensions/adk/test_store.py @@ -54,12 +54,12 @@ async def test_storage_types_verification(asyncmy_adk_store: AsyncmyADKStore) -> event_columns = await cursor.fetchall() event_col_names = [col[0] for col in event_columns] - # New 5-column schema: session_id, invocation_id, author, timestamp, event_json + # New 5-column schema: session_id, invocation_id, author, timestamp, event_data assert "session_id" in event_col_names assert "invocation_id" in event_col_names assert "author" in event_col_names assert "timestamp" in event_col_names - assert "event_json" in event_col_names + assert "event_data" in event_col_names timestamp_col = next(col for col in event_columns if col[0] == "timestamp") assert "timestamp(6)" in timestamp_col[2].lower(), "timestamp must be TIMESTAMP(6) for microseconds" @@ -148,7 +148,7 @@ async def test_delete_session_cascade(asyncmy_adk_store: AsyncmyADKStore) -> Non "invocation_id": "inv-001", "author": "user", "timestamp": datetime.now(timezone.utc), - "event_json": {"content": {"text": "Hello"}, "app_name": app_name, "user_id": user_id}, + "event_data": {"content": {"text": "Hello"}, "app_name": app_name, "user_id": user_id}, } await asyncmy_adk_store.append_event(event_record) @@ -177,7 +177,7 @@ async def test_append_and_get_events(asyncmy_adk_store: AsyncmyADKStore) -> None "invocation_id": "inv-001", "author": "user", "timestamp": datetime.now(timezone.utc), - "event_json": {"content": {"text": "Hello", "role": "user"}, "app_name": app_name}, + "event_data": {"content": {"text": "Hello", "role": "user"}, "app_name": app_name}, } event2: EventRecord = { @@ -185,7 +185,7 @@ async def test_append_and_get_events(asyncmy_adk_store: AsyncmyADKStore) -> None "invocation_id": "inv-002", "author": "assistant", "timestamp": datetime.now(timezone.utc), - "event_json": {"content": {"text": "Hi there", "role": "assistant"}, "app_name": app_name}, + "event_data": {"content": {"text": "Hi there", "role": "assistant"}, "app_name": app_name}, } await asyncmy_adk_store.append_event(event1) @@ -196,12 +196,12 @@ async def test_append_and_get_events(asyncmy_adk_store: AsyncmyADKStore) -> None assert len(events) == 2 assert events[0]["author"] == "user" assert events[1]["author"] == "assistant" - # Content is inside event_json + # Content is inside event_data event0_data = ( - json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + json.loads(events[0]["event_data"]) if isinstance(events[0]["event_data"], str) else events[0]["event_data"] ) event1_data = ( - json.loads(events[1]["event_json"]) if isinstance(events[1]["event_json"], str) else events[1]["event_json"] + json.loads(events[1]["event_data"]) if isinstance(events[1]["event_data"], str) else events[1]["event_data"] ) assert event0_data["content"]["text"] == "Hello" assert event1_data["content"]["text"] == "Hi there" @@ -224,7 +224,7 @@ async def test_timestamp_precision(asyncmy_adk_store: AsyncmyADKStore) -> None: "invocation_id": "inv-micro", "author": "system", "timestamp": event_time, - "event_json": {"app_name": app_name}, + "event_data": {"app_name": app_name}, } await asyncmy_adk_store.append_event(event) diff --git a/tests/integration/adapters/duckdb/extensions/adk/test_store.py b/tests/integration/adapters/duckdb/extensions/adk/test_store.py index 915593b62..0fb48e0ae 100644 --- a/tests/integration/adapters/duckdb/extensions/adk/test_store.py +++ b/tests/integration/adapters/duckdb/extensions/adk/test_store.py @@ -146,7 +146,7 @@ async def test_delete_session_cascade_events(duckdb_adk_store: DuckdbADKStore) - "invocation_id": "", "author": "user", "timestamp": datetime.now(timezone.utc), - "event_json": { + "event_data": { "id": "event-001", "content": {"message": "Hello"}, "app_name": "test-app", @@ -178,7 +178,7 @@ async def test_create_event(duckdb_adk_store: DuckdbADKStore) -> None: "invocation_id": "", "author": "user", "timestamp": timestamp, - "event_json": {"id": "event-002", "content": content, "app_name": "test-app", "user_id": "user-006"}, + "event_data": {"id": "event-002", "content": content, "app_name": "test-app", "user_id": "user-006"}, } await duckdb_adk_store.append_event(event_record) @@ -187,9 +187,9 @@ async def test_create_event(duckdb_adk_store: DuckdbADKStore) -> None: assert events[0]["session_id"] == session_id assert events[0]["author"] == "user" - # Content is stored inside event_json + # Content is stored inside event_data event_data = ( - json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + json.loads(events[0]["event_data"]) if isinstance(events[0]["event_data"], str) else events[0]["event_data"] ) assert event_data["content"] == content @@ -204,14 +204,14 @@ async def test_list_events(duckdb_adk_store: DuckdbADKStore) -> None: "invocation_id": "", "author": "user", "timestamp": datetime.now(timezone.utc), - "event_json": {"id": "event-1", "content": {"message": "First"}, "app_name": "test-app", "user_id": "user-007"}, + "event_data": {"id": "event-1", "content": {"message": "First"}, "app_name": "test-app", "user_id": "user-007"}, } event2: EventRecord = { "session_id": session_id, "invocation_id": "", "author": "assistant", "timestamp": datetime.now(timezone.utc), - "event_json": { + "event_data": { "id": "event-2", "content": {"message": "Second"}, "app_name": "test-app", @@ -239,7 +239,7 @@ async def test_list_events_empty(duckdb_adk_store: DuckdbADKStore) -> None: async def test_event_with_optional_fields(duckdb_adk_store: DuckdbADKStore) -> None: - """Test creating events with optional fields stored in event_json.""" + """Test creating events with optional fields stored in event_data.""" session_id = "session-008" await duckdb_adk_store.create_session(session_id, "test-app", "user-008", {}) @@ -248,7 +248,7 @@ async def test_event_with_optional_fields(duckdb_adk_store: DuckdbADKStore) -> N "invocation_id": "inv-123", "author": "assistant", "timestamp": datetime.now(timezone.utc), - "event_json": { + "event_data": { "id": "event-full", "content": {"text": "Response"}, "app_name": "test-app", @@ -269,9 +269,9 @@ async def test_event_with_optional_fields(duckdb_adk_store: DuckdbADKStore) -> N # The 5-key record has invocation_id as a top-level indexed column assert events[0]["invocation_id"] == "inv-123" - # Other fields are inside event_json + # Other fields are inside event_data event_data = ( - json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + json.loads(events[0]["event_data"]) if isinstance(events[0]["event_data"], str) else events[0]["event_data"] ) assert event_data["branch"] == "main" assert event_data["grounding_metadata"] == {"sources": ["doc1", "doc2"]} @@ -293,21 +293,21 @@ async def test_event_ordering_by_timestamp(duckdb_adk_store: DuckdbADKStore) -> "invocation_id": "", "author": "", "timestamp": t2, - "event_json": {"id": "event-middle", "app_name": "test-app", "user_id": "user-009"}, + "event_data": {"id": "event-middle", "app_name": "test-app", "user_id": "user-009"}, } ev_last: EventRecord = { "session_id": session_id, "invocation_id": "", "author": "", "timestamp": t3, - "event_json": {"id": "event-last", "app_name": "test-app", "user_id": "user-009"}, + "event_data": {"id": "event-last", "app_name": "test-app", "user_id": "user-009"}, } ev_first: EventRecord = { "session_id": session_id, "invocation_id": "", "author": "", "timestamp": t1, - "event_json": {"id": "event-first", "app_name": "test-app", "user_id": "user-009"}, + "event_data": {"id": "event-first", "app_name": "test-app", "user_id": "user-009"}, } await duckdb_adk_store.append_event(ev_middle) @@ -320,7 +320,7 @@ async def test_event_ordering_by_timestamp(duckdb_adk_store: DuckdbADKStore) -> # Events should be ordered by timestamp ASC event_ids = [] for e in events: - data = json.loads(e["event_json"]) if isinstance(e["event_json"], str) else e["event_json"] + data = json.loads(e["event_data"]) if isinstance(e["event_data"], str) else e["event_data"] event_ids.append(data["id"]) assert event_ids == ["event-first", "event-middle", "event-last"] @@ -377,8 +377,8 @@ async def test_table_not_found_handling(tmp_path: Path) -> None: db_path.unlink() -async def test_event_json_round_trip(duckdb_adk_store: DuckdbADKStore) -> None: - """Test storing and retrieving event data via event_json.""" +async def test_event_data_round_trip(duckdb_adk_store: DuckdbADKStore) -> None: + """Test storing and retrieving event data via event_data.""" session_id = "session-json-rt" await duckdb_adk_store.create_session(session_id, "test-app", "user-012", {}) @@ -387,14 +387,14 @@ async def test_event_json_round_trip(duckdb_adk_store: DuckdbADKStore) -> None: "invocation_id": "", "author": "system", "timestamp": datetime.now(timezone.utc), - "event_json": {"id": "event-json", "content": {"data": "value"}, "app_name": "test-app", "user_id": "user-012"}, + "event_data": {"id": "event-json", "content": {"data": "value"}, "app_name": "test-app", "user_id": "user-012"}, } await duckdb_adk_store.append_event(event_record) events = await duckdb_adk_store.get_events(session_id) assert len(events) == 1 event_data = ( - json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + json.loads(events[0]["event_data"]) if isinstance(events[0]["event_data"], str) else events[0]["event_data"] ) assert event_data["content"] == {"data": "value"} diff --git a/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py b/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py index def390f0f..ffe142eef 100644 --- a/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py +++ b/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py @@ -53,12 +53,12 @@ async def test_storage_types_verification(mysqlconnector_adk_store: MysqlConnect event_columns = await cursor.fetchall() event_col_names = [col[0] for col in event_columns] - # New 5-column schema: session_id, invocation_id, author, timestamp, event_json + # New 5-column schema: session_id, invocation_id, author, timestamp, event_data assert "session_id" in event_col_names assert "invocation_id" in event_col_names assert "author" in event_col_names assert "timestamp" in event_col_names - assert "event_json" in event_col_names + assert "event_data" in event_col_names timestamp_col = next(col for col in event_columns if col[0] == "timestamp") assert "timestamp(6)" in cast("str", timestamp_col[2]).lower() @@ -149,7 +149,7 @@ async def test_delete_session_cascade(mysqlconnector_adk_store: MysqlConnectorAs "invocation_id": "inv-001", "author": "user", "timestamp": datetime.now(timezone.utc), - "event_json": {"content": {"text": "Hello"}, "app_name": app_name, "user_id": user_id}, + "event_data": {"content": {"text": "Hello"}, "app_name": app_name, "user_id": user_id}, } await mysqlconnector_adk_store.append_event(event_record) @@ -178,7 +178,7 @@ async def test_append_and_get_events(mysqlconnector_adk_store: MysqlConnectorAsy "invocation_id": "inv-001", "author": "user", "timestamp": datetime.now(timezone.utc), - "event_json": {"content": {"text": "Hello", "role": "user"}, "app_name": app_name}, + "event_data": {"content": {"text": "Hello", "role": "user"}, "app_name": app_name}, } event2: EventRecord = { @@ -186,7 +186,7 @@ async def test_append_and_get_events(mysqlconnector_adk_store: MysqlConnectorAsy "invocation_id": "inv-002", "author": "assistant", "timestamp": datetime.now(timezone.utc), - "event_json": {"content": {"text": "Hi there", "role": "assistant"}, "app_name": app_name}, + "event_data": {"content": {"text": "Hi there", "role": "assistant"}, "app_name": app_name}, } await mysqlconnector_adk_store.append_event(event1) @@ -197,12 +197,12 @@ async def test_append_and_get_events(mysqlconnector_adk_store: MysqlConnectorAsy assert len(events) == 2 assert events[0]["author"] == "user" assert events[1]["author"] == "assistant" - # Content is inside event_json + # Content is inside event_data event0_data = ( - json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + json.loads(events[0]["event_data"]) if isinstance(events[0]["event_data"], str) else events[0]["event_data"] ) event1_data = ( - json.loads(events[1]["event_json"]) if isinstance(events[1]["event_json"], str) else events[1]["event_json"] + json.loads(events[1]["event_data"]) if isinstance(events[1]["event_data"], str) else events[1]["event_data"] ) assert event0_data["content"]["text"] == "Hello" assert event1_data["content"]["text"] == "Hi there" @@ -224,7 +224,7 @@ async def test_timestamp_precision(mysqlconnector_adk_store: MysqlConnectorAsync "invocation_id": "inv-micro", "author": "system", "timestamp": event_time, - "event_json": {"app_name": app_name}, + "event_data": {"app_name": app_name}, } await mysqlconnector_adk_store.append_event(event) diff --git a/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py b/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py index 1e6cb1f94..09652a152 100644 --- a/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py +++ b/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py @@ -222,8 +222,8 @@ async def test_state_lob_deserialization(oracle_async_store: "OracleAsyncADKStor assert retrieved["state"]["large_field"] == "x" * 10000 -async def test_event_json_lob_deserialization(oracle_async_store: "OracleAsyncADKStore") -> None: - """Test event_json LOB data is correctly deserialized.""" +async def test_event_data_lob_deserialization(oracle_async_store: "OracleAsyncADKStore") -> None: + """Test event_data LOB data is correctly deserialized.""" session_id = _unique_session_id("event-lob") app_name = "test-app" user_id = "user-123" @@ -244,24 +244,24 @@ async def test_event_json_lob_deserialization(oracle_async_store: "OracleAsyncAD "invocation_id": "", "author": "assistant", "timestamp": datetime.now(timezone.utc), - "event_json": event_data, + "event_data": event_data, } await oracle_async_store.append_event(event_record) events = await oracle_async_store.get_events(session_id) assert len(events) == 1 - # event_json contains all the data + # event_data contains all the data retrieved_data = ( - json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + json.loads(events[0]["event_data"]) if isinstance(events[0]["event_data"], str) else events[0]["event_data"] ) assert retrieved_data["content"] == content assert retrieved_data["grounding_metadata"] == {"sources": ["a" * 1000, "b" * 1000]} assert retrieved_data["custom_metadata"] == {"tags": ["tag1", "tag2"], "priority": "high"} -async def test_event_json_storage(oracle_async_store: "OracleAsyncADKStore") -> None: - """Test event_json blob is correctly stored and retrieved.""" +async def test_event_data_storage(oracle_async_store: "OracleAsyncADKStore") -> None: + """Test event_data blob is correctly stored and retrieved.""" session_id = _unique_session_id("event-json") app_name = "test-app" user_id = "user-123" @@ -275,7 +275,7 @@ async def test_event_json_storage(oracle_async_store: "OracleAsyncADKStore") -> "invocation_id": "", "author": "user", "timestamp": datetime.now(timezone.utc), - "event_json": event_data, + "event_data": event_data, } await oracle_async_store.append_event(event_record) @@ -283,7 +283,7 @@ async def test_event_json_storage(oracle_async_store: "OracleAsyncADKStore") -> events = await oracle_async_store.get_events(session_id) assert len(events) == 1 retrieved_data = ( - json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + json.loads(events[0]["event_data"]) if isinstance(events[0]["event_data"], str) else events[0]["event_data"] ) assert retrieved_data == event_data @@ -316,7 +316,7 @@ async def test_event_record_5_column_contract(oracle_async_store: "OracleAsyncAD "invocation_id": "inv-001", "author": "assistant", "timestamp": datetime.now(timezone.utc), - "event_json": {"content": {"text": "Hello"}, "partial": True, "turn_complete": False, "interrupted": True}, + "event_data": {"content": {"text": "Hello"}, "partial": True, "turn_complete": False, "interrupted": True}, } await oracle_async_store.append_event(event_record) @@ -328,7 +328,7 @@ async def test_event_record_5_column_contract(oracle_async_store: "OracleAsyncAD assert events[0]["author"] == "assistant" retrieved_data = ( - json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + json.loads(events[0]["event_data"]) if isinstance(events[0]["event_data"], str) else events[0]["event_data"] ) assert retrieved_data["partial"] is True assert retrieved_data["turn_complete"] is False @@ -336,7 +336,7 @@ async def test_event_record_5_column_contract(oracle_async_store: "OracleAsyncAD async def test_event_with_none_values(oracle_async_store: "OracleAsyncADKStore") -> None: - """Test event with minimal event_json content.""" + """Test event with minimal event_data content.""" session_id = _unique_session_id("none-session") app_name = "test-app" user_id = "user-123" @@ -348,7 +348,7 @@ async def test_event_with_none_values(oracle_async_store: "OracleAsyncADKStore") "invocation_id": "", "author": "user", "timestamp": datetime.now(timezone.utc), - "event_json": {"app_name": app_name}, + "event_data": {"app_name": app_name}, } await oracle_async_store.append_event(event_record) diff --git a/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py b/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py index b7cca39f2..5ff1ee911 100644 --- a/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py +++ b/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py @@ -66,14 +66,14 @@ async def test_create_and_list_events(spanner_adk_store: Any) -> None: "invocation_id": "event-1", "author": "user", "timestamp": datetime.now(timezone.utc), - "event_json": {"id": "event-1", "content": {"msg": "hi"}, "app_name": "app", "user_id": "user"}, + "event_data": {"id": "event-1", "content": {"msg": "hi"}, "app_name": "app", "user_id": "user"}, } event_two: EventRecord = { "session_id": session_id, "invocation_id": "event-2", "author": "assistant", "timestamp": datetime.now(timezone.utc), - "event_json": {"id": "event-2", "content": {"msg": "ok"}, "app_name": "app", "user_id": "user"}, + "event_data": {"id": "event-2", "content": {"msg": "ok"}, "app_name": "app", "user_id": "user"}, } await spanner_adk_store.append_event(event_one) @@ -84,12 +84,12 @@ async def test_create_and_list_events(spanner_adk_store: Any) -> None: assert events[0]["author"] == "user" assert events[1]["author"] == "assistant" - # Content is inside event_json in the new 5-column schema + # Content is inside event_data in the new 5-column schema event0_data = ( - json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + json.loads(events[0]["event_data"]) if isinstance(events[0]["event_data"], str) else events[0]["event_data"] ) event1_data = ( - json.loads(events[1]["event_json"]) if isinstance(events[1]["event_json"], str) else events[1]["event_json"] + json.loads(events[1]["event_data"]) if isinstance(events[1]["event_data"], str) else events[1]["event_data"] ) assert event0_data["content"] == {"msg": "hi"} assert event1_data["content"] == {"msg": "ok"} diff --git a/tests/integration/adapters/sqlite/extensions/adk/test_store.py b/tests/integration/adapters/sqlite/extensions/adk/test_store.py index 5f6d4e863..3e31fab35 100644 --- a/tests/integration/adapters/sqlite/extensions/adk/test_store.py +++ b/tests/integration/adapters/sqlite/extensions/adk/test_store.py @@ -46,7 +46,7 @@ async def test_sqlite_append_event_and_update_state_is_atomic_contract(tmp_path: "invocation_id": "inv-1", "author": "user", "timestamp": datetime(2026, 5, 10, 12, 0, tzinfo=timezone.utc), - "event_json": {"id": "event-1", "content": {"parts": [{"text": "hello"}]}}, + "event_data": {"id": "event-1", "content": {"parts": [{"text": "hello"}]}}, } await store.append_event_and_update_state(event, session_id, {"turn": 1}) @@ -57,7 +57,7 @@ async def test_sqlite_append_event_and_update_state_is_atomic_contract(tmp_path: assert session["state"] == {"turn": 1} assert len(events) == 1 assert events[0]["invocation_id"] == "inv-1" - assert events[0]["event_json"] == {"id": "event-1", "content": {"parts": [{"text": "hello"}]}} + assert events[0]["event_data"] == {"id": "event-1", "content": {"parts": [{"text": "hello"}]}} finally: config.close_pool() @@ -90,7 +90,7 @@ async def test_sqlite_get_events_filters_by_timestamp_and_limit(tmp_path: Path) "invocation_id": f"inv-{index}", "author": "user", "timestamp": base + timedelta(seconds=index), - "event_json": {"id": f"event-{index}"}, + "event_data": {"id": f"event-{index}"}, } await store.append_event(event) @@ -98,6 +98,6 @@ async def test_sqlite_get_events_filters_by_timestamp_and_limit(tmp_path: Path) assert len(events) == 1 assert events[0]["invocation_id"] == "inv-1" - assert events[0]["event_json"] == {"id": "event-1"} + assert events[0]["event_data"] == {"id": "event-1"} finally: config.close_pool() diff --git a/tests/unit/adapters/test_oracledb/test_oracle_adk_store.py b/tests/unit/adapters/test_oracledb/test_oracle_adk_store.py index 38bbca7a8..630713232 100644 --- a/tests/unit/adapters/test_oracledb/test_oracle_adk_store.py +++ b/tests/unit/adapters/test_oracledb/test_oracle_adk_store.py @@ -12,7 +12,7 @@ OracleSyncADKMemoryStore, OracleSyncADKStore, ) -from sqlspec.adapters.oracledb.adk.store import _event_json_column_ddl +from sqlspec.adapters.oracledb.adk.store import _event_data_column_ddl def _mock_config(adk_config: dict[str, object]) -> MagicMock: @@ -61,27 +61,27 @@ def test_oracle_sync_adk_store_deserialize_state_dict_coerces_decimal() -> None: assert result == {"state": 5.0} -def test_oracle_event_json_column_ddl_prefers_blob_over_clob() -> None: - assert _event_json_column_ddl(JSONStorageType.JSON_NATIVE) == "event_json JSON NOT NULL" - assert _event_json_column_ddl(JSONStorageType.BLOB_JSON) == "event_json BLOB CHECK (event_json IS JSON) NOT NULL" - assert _event_json_column_ddl(JSONStorageType.BLOB_PLAIN) == "event_json BLOB NOT NULL" +def test_oracle_event_data_column_ddl_prefers_blob_over_clob() -> None: + assert _event_data_column_ddl(JSONStorageType.JSON_NATIVE) == "event_data JSON NOT NULL" + assert _event_data_column_ddl(JSONStorageType.BLOB_JSON) == "event_data BLOB CHECK (event_data IS JSON) NOT NULL" + assert _event_data_column_ddl(JSONStorageType.BLOB_PLAIN) == "event_data BLOB NOT NULL" -async def test_oracle_async_adk_store_serialize_event_json_uses_blob_for_non_native() -> None: +async def test_oracle_async_adk_store_serialize_event_data_uses_blob_for_non_native() -> None: store = OracleAsyncADKStore.__new__(OracleAsyncADKStore) # type: ignore[call-arg] store._json_storage_type = JSONStorageType.BLOB_JSON # type: ignore[attr-defined] - result = await store._serialize_event_json({"value": 1}) # type: ignore[attr-defined] + result = await store._serialize_event_data({"value": 1}) # type: ignore[attr-defined] assert isinstance(result, bytes) assert b'"value":1' in result -def test_oracle_sync_adk_store_serialize_event_json_uses_blob_for_non_native() -> None: +def test_oracle_sync_adk_store_serialize_event_data_uses_blob_for_non_native() -> None: store = OracleSyncADKStore.__new__(OracleSyncADKStore) # type: ignore[call-arg] store._json_storage_type = JSONStorageType.BLOB_JSON # type: ignore[attr-defined] - result = store._serialize_event_json({"value": 1}) # type: ignore[attr-defined] + result = store._serialize_event_data({"value": 1}) # type: ignore[attr-defined] assert isinstance(result, bytes) assert b'"value":1' in result diff --git a/tests/unit/adapters/test_psycopg/test_adk_store.py b/tests/unit/adapters/test_psycopg/test_adk_store.py index 754fc56cd..6320fac44 100644 --- a/tests/unit/adapters/test_psycopg/test_adk_store.py +++ b/tests/unit/adapters/test_psycopg/test_adk_store.py @@ -76,7 +76,7 @@ def test_sync_append_event_inserts_without_session_update() -> None: "invocation_id": "", "author": "assistant", "timestamp": datetime.now(timezone.utc), - "event_json": {"id": "event-1"}, + "event_data": {"id": "event-1"}, } store._append_event(event_record) # type: ignore[arg-type] @@ -97,7 +97,7 @@ def test_sync_get_events_passes_after_timestamp_and_limit() -> None: "invocation_id": "", "author": "assistant", "timestamp": base_time, - "event_json": {"id": "event-2"}, + "event_data": {"id": "event-2"}, } ] store, cursor, _ = _build_store(rows) @@ -107,4 +107,4 @@ def test_sync_get_events_passes_after_timestamp_and_limit() -> None: assert len(cursor.execute_calls) == 1 _, params = cursor.execute_calls[0] assert params == ("session-1", base_time, 1) - assert result[0]["event_json"]["id"] == "event-2" + assert result[0]["event_data"]["id"] == "event-2" diff --git a/tests/unit/adapters/test_spanner/test_adk_store.py b/tests/unit/adapters/test_spanner/test_adk_store.py index 2081df46d..50d604091 100644 --- a/tests/unit/adapters/test_spanner/test_adk_store.py +++ b/tests/unit/adapters/test_spanner/test_adk_store.py @@ -23,7 +23,7 @@ def test_insert_event_preserves_event_record_timestamp() -> None: "invocation_id": "inv-1", "author": "user", "timestamp": timestamp, - "event_json": {"id": "event-1"}, + "event_data": {"id": "event-1"}, } with patch.object(store, "_run_write") as run_write: @@ -45,7 +45,7 @@ async def test_append_event_and_update_state_preserves_event_record_timestamp() "invocation_id": "inv-1", "author": "user", "timestamp": timestamp, - "event_json": {"id": "event-1"}, + "event_data": {"id": "event-1"}, } # Stub the post-write SELECT — the contract requires returning the refreshed record. fake_record = { diff --git a/tests/unit/extensions/test_adk/test_converters.py b/tests/unit/extensions/test_adk/test_converters.py index cf866bef1..b7e87770d 100644 --- a/tests/unit/extensions/test_adk/test_converters.py +++ b/tests/unit/extensions/test_adk/test_converters.py @@ -1,7 +1,7 @@ """Unit tests for ADK session/event converters and scoped state helpers. Tests the NEW contract specified in Chapter 1 of the ADK Clean-Break Overhaul: -- EventRecord has exactly 5 keys (session_id, invocation_id, author, timestamp, event_json) +- EventRecord has exactly 5 keys (session_id, invocation_id, author, timestamp, event_data) - event_to_record takes only (event, session_id), not (event, session_id, app_name, user_id) - record_to_event uses Event.model_validate for full round-trip fidelity - filter_temp_state, split_scoped_state, merge_scoped_state for scoped state handling @@ -221,10 +221,10 @@ def test_merge_scoped_state_does_not_mutate_session_state() -> None: def test_event_to_record_only_5_keys() -> None: - """EventRecord has exactly session_id, invocation_id, author, timestamp, event_json.""" + """EventRecord has exactly session_id, invocation_id, author, timestamp, event_data.""" event = _make_event() record = event_to_record(event, "session-1") - assert set(record.keys()) == {"session_id", "invocation_id", "author", "timestamp", "event_json"} + assert set(record.keys()) == {"session_id", "invocation_id", "author", "timestamp", "event_data"} def test_event_to_record_signature_two_args_only() -> None: @@ -250,29 +250,29 @@ def test_event_to_record_indexed_fields_match_event() -> None: assert isinstance(record["timestamp"], datetime) -def test_event_to_record_event_json_matches_model_dump() -> None: - """event_json in the record equals event.model_dump(exclude_none=True, mode='json').""" +def test_event_to_record_event_data_matches_model_dump() -> None: + """event_data in the record equals event.model_dump(exclude_none=True, mode='json').""" event = _make_event(text="hello", state_delta={"key": "val"}, custom_metadata={"foo": "bar"}) record = event_to_record(event, "s1") expected_json = event.model_dump(exclude_none=True, mode="json") - assert record["event_json"] == expected_json + assert record["event_data"] == expected_json -def test_event_to_record_event_json_is_dict() -> None: - """event_json field is a plain dict (not bytes, not string).""" +def test_event_to_record_event_data_is_dict() -> None: + """event_data field is a plain dict (not bytes, not string).""" event = _make_event() record = event_to_record(event, "s1") - assert isinstance(record["event_json"], dict) + assert isinstance(record["event_data"], dict) -def test_event_to_record_actions_in_event_json_is_structured() -> None: - """Actions are stored as structured JSON dict in event_json, not as raw bytes.""" +def test_event_to_record_actions_in_event_data_is_structured() -> None: + """Actions are stored as structured JSON dict in event_data, not as raw bytes.""" event = _make_event(state_delta={"x": "y"}) record = event_to_record(event, "s1") - event_json = record["event_json"] + event_data = record["event_data"] # actions should be a dict in the JSON blob - if "actions" in event_json: - assert isinstance(event_json["actions"], dict) + if "actions" in event_data: + assert isinstance(event_data["actions"], dict) def test_event_to_record_timestamp_is_datetime() -> None: @@ -366,13 +366,13 @@ def test_record_to_event_roundtrip_preserves_timestamp() -> None: assert abs(restored.timestamp - fixed_ts) < 1.0 # within 1 second -def test_record_to_event_ignores_unknown_fields_in_event_json() -> None: - """Unknown event_json fields are ignored by the current ADK Event model.""" +def test_record_to_event_ignores_unknown_fields_in_event_data() -> None: + """Unknown event_data fields are ignored by the current ADK Event model.""" event = _make_event(event_id="extra-fields-evt", author="tool") record = event_to_record(event, "s1") - # Inject hypothetical future ADK field into event_json - record["event_json"]["hypothetical_v3_field"] = "some_value" # type: ignore[index] + # Inject hypothetical future ADK field into event_data + record["event_data"]["hypothetical_v3_field"] = "some_value" # type: ignore[index] restored = record_to_event(record) assert restored.id == "extra-fields-evt" diff --git a/tests/unit/extensions/test_adk/test_service.py b/tests/unit/extensions/test_adk/test_service.py index d479980ce..8a6bab71f 100644 --- a/tests/unit/extensions/test_adk/test_service.py +++ b/tests/unit/extensions/test_adk/test_service.py @@ -272,7 +272,7 @@ async def test_append_event_event_record_has_5_keys() -> None: last_call = store.append_event_and_update_state_calls[-1] event_record = last_call["event_record"] - assert set(event_record.keys()) == {"session_id", "invocation_id", "author", "timestamp", "event_json"} + assert set(event_record.keys()) == {"session_id", "invocation_id", "author", "timestamp", "event_data"} @pytest.mark.anyio diff --git a/tests/unit/extensions/test_adk/test_store_config.py b/tests/unit/extensions/test_adk/test_store_config.py index 6d172bbac..414f5d980 100644 --- a/tests/unit/extensions/test_adk/test_store_config.py +++ b/tests/unit/extensions/test_adk/test_store_config.py @@ -70,7 +70,7 @@ def create_event( invocation_id=event_id, author=author or user_id, timestamp=datetime.now(), - event_json=content or {}, + event_data=content or {}, ) def create_event_and_update_state(self, event_record: EventRecord, session_id: str, state: dict[str, Any]) -> None: From f90cd1ecee4227544ee0a9dae3ff40493306e2cc Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 23 May 2026 19:01:46 +0000 Subject: [PATCH 05/29] feat(adk): unify session memory store contracts --- docs/extensions/adk/api.rst | 12 - docs/reference/extensions/adk.rst | 8 - sqlspec/cli.py | 21 +- sqlspec/extensions/adk/__init__.py | 17 +- sqlspec/extensions/adk/memory/__init__.py | 6 +- sqlspec/extensions/adk/memory/store.py | 232 +--------------- .../adk/migrations/0001_create_adk_tables.py | 13 +- sqlspec/extensions/adk/store.py | 258 +----------------- .../extensions/test_adk/test_store_config.py | 105 +++---- .../test_adk/test_store_instantiation.py | 9 +- 10 files changed, 84 insertions(+), 597 deletions(-) diff --git a/docs/extensions/adk/api.rst b/docs/extensions/adk/api.rst index 57429736c..1d4aa6cd7 100644 --- a/docs/extensions/adk/api.rst +++ b/docs/extensions/adk/api.rst @@ -34,12 +34,6 @@ Session Stores :show-inheritance: :no-index: -.. autoclass:: BaseSyncADKStore - :members: - :undoc-members: - :show-inheritance: - :no-index: - Memory Stores ============= @@ -49,12 +43,6 @@ Memory Stores :show-inheritance: :no-index: -.. autoclass:: BaseSyncADKMemoryStore - :members: - :undoc-members: - :show-inheritance: - :no-index: - Artifact Stores =============== diff --git a/docs/reference/extensions/adk.rst b/docs/reference/extensions/adk.rst index 5a1dc9c9b..6079e1fa6 100644 --- a/docs/reference/extensions/adk.rst +++ b/docs/reference/extensions/adk.rst @@ -33,10 +33,6 @@ Store Base Classes :members: :show-inheritance: -.. autoclass:: sqlspec.extensions.adk.BaseSyncADKStore - :members: - :show-inheritance: - Memory Store Base Classes ========================= @@ -44,10 +40,6 @@ Memory Store Base Classes :members: :show-inheritance: -.. autoclass:: sqlspec.extensions.adk.BaseSyncADKMemoryStore - :members: - :show-inheritance: - Artifact Store Base Classes =========================== diff --git a/sqlspec/cli.py b/sqlspec/cli.py index 2940d70fb..f4dcb34c3 100644 --- a/sqlspec/cli.py +++ b/sqlspec/cli.py @@ -19,7 +19,7 @@ if TYPE_CHECKING: from rich_click import Group - from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore, BaseSyncADKMemoryStore + from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore from sqlspec.migrations.commands import AsyncMigrationCommands, SyncMigrationCommands __all__ = ("add_migration_commands", "get_sqlspec_group") @@ -247,7 +247,7 @@ def _get_adk_configs( def _get_memory_store_class( config: "AsyncDatabaseConfig[Any, Any, Any] | SyncDatabaseConfig[Any, Any, Any]", - ) -> "type[BaseAsyncADKMemoryStore[Any] | BaseSyncADKMemoryStore[Any]] | None": + ) -> "type[BaseAsyncADKMemoryStore[Any]] | None": config_module = type(config).__module__ config_name = type(config).__name__ @@ -259,7 +259,7 @@ def _get_memory_store_class( store_path = f"sqlspec.adapters.{adapter_name}.adk.store.{store_class_name}" try: - return cast("type[BaseAsyncADKMemoryStore[Any] | BaseSyncADKMemoryStore[Any]]", import_string(store_path)) + return cast("type[BaseAsyncADKMemoryStore[Any]]", import_string(store_path)) except ImportError: return None @@ -1045,13 +1045,8 @@ def cleanup_memory(bind_key: str | None, days: int) -> None: # pyright: ignore[ console.print(f"[yellow]No memory store found for {config_name}; skipping.[/]") continue - if isinstance(cfg, AsyncDatabaseConfig): - async_store = cast("BaseAsyncADKMemoryStore[Any]", store_class(cfg)) - deleted = run_(_cleanup_memory_entries_async)(async_store, days) - console.print(f"[green]✓[/] {config_name}: deleted {deleted} memory entries older than {days} days") - continue - sync_store = cast("BaseSyncADKMemoryStore[Any]", store_class(cfg)) - deleted = sync_store.delete_entries_older_than(days) + store = store_class(cfg) + deleted = run_(_cleanup_memory_entries_async)(store, days) console.print(f"[green]✓[/] {config_name}: deleted {deleted} memory entries older than {days} days") @adk_memory_group.command(name="verify", help="Verify memory table exists and is reachable") @@ -1077,15 +1072,13 @@ def verify_memory(bind_key: str | None) -> None: # pyright: ignore[reportUnused continue try: + store = store_class(cfg) + sql = f"SELECT 1 FROM {store.memory_table} WHERE 1 = 0" if isinstance(cfg, AsyncDatabaseConfig): async_cfg: AsyncDatabaseConfig[Any, Any, Any] = cfg - async_store = cast("BaseAsyncADKMemoryStore[Any]", store_class(async_cfg)) - sql = f"SELECT 1 FROM {async_store.memory_table} WHERE 1 = 0" run_(_verify_memory_table_async)(async_cfg, sql) console.print(f"[green]✓[/] {config_name}: memory table reachable") continue - sync_store = cast("BaseSyncADKMemoryStore[Any]", store_class(cfg)) - sql = f"SELECT 1 FROM {sync_store.memory_table} WHERE 1 = 0" with cfg.provide_session() as driver: driver.execute(sql) console.print(f"[green]✓[/] {config_name}: memory table reachable") diff --git a/sqlspec/extensions/adk/__init__.py b/sqlspec/extensions/adk/__init__.py index 6cf33c24f..02f89af28 100644 --- a/sqlspec/extensions/adk/__init__.py +++ b/sqlspec/extensions/adk/__init__.py @@ -8,10 +8,8 @@ - SQLSpecSessionService: Main service class implementing BaseSessionService - SQLSpecMemoryService: Main async service class implementing BaseMemoryService - SQLSpecArtifactService: Artifact service implementing BaseArtifactService - - BaseAsyncADKStore: Base class for async database store implementations - - BaseSyncADKStore: Base class for sync database store implementations - - BaseAsyncADKMemoryStore: Base class for async memory store implementations - - BaseSyncADKMemoryStore: Base class for sync memory store implementations + - BaseAsyncADKStore: Base class for ADK session store implementations + - BaseAsyncADKMemoryStore: Base class for ADK memory store implementations - BaseAsyncADKArtifactStore: Base class for async artifact metadata stores - BaseSyncADKArtifactStore: Base class for sync artifact metadata stores - SessionRecord: TypedDict for session database records @@ -54,14 +52,9 @@ BaseSyncADKArtifactStore, SQLSpecArtifactService, ) -from sqlspec.extensions.adk.memory import ( - BaseAsyncADKMemoryStore, - BaseSyncADKMemoryStore, - MemoryRecord, - SQLSpecMemoryService, -) +from sqlspec.extensions.adk.memory import BaseAsyncADKMemoryStore, MemoryRecord, SQLSpecMemoryService from sqlspec.extensions.adk.service import SQLSpecSessionService -from sqlspec.extensions.adk.store import BaseAsyncADKStore, BaseSyncADKStore +from sqlspec.extensions.adk.store import BaseAsyncADKStore __all__ = ( "ADKConfig", @@ -70,8 +63,6 @@ "BaseAsyncADKMemoryStore", "BaseAsyncADKStore", "BaseSyncADKArtifactStore", - "BaseSyncADKMemoryStore", - "BaseSyncADKStore", "EventRecord", "MemoryRecord", "SQLSpecArtifactService", diff --git a/sqlspec/extensions/adk/memory/__init__.py b/sqlspec/extensions/adk/memory/__init__.py index 5c661a8f5..70a70a5d2 100644 --- a/sqlspec/extensions/adk/memory/__init__.py +++ b/sqlspec/extensions/adk/memory/__init__.py @@ -6,8 +6,7 @@ Public API exports: - SQLSpecMemoryService: Main async service class implementing BaseMemoryService - - BaseAsyncADKMemoryStore: Base class for async database store implementations - - BaseSyncADKMemoryStore: Internal base for sync stores wrapped behind async APIs + - BaseAsyncADKMemoryStore: Base class for ADK memory store implementations - MemoryRecord: TypedDict for memory database records - extract_content_text: Helper to extract searchable text from Content - session_to_memory_records: Convert Session to memory records @@ -54,11 +53,10 @@ session_to_memory_records, ) from sqlspec.extensions.adk.memory.service import SQLSpecMemoryService -from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore, BaseSyncADKMemoryStore +from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore __all__ = ( "BaseAsyncADKMemoryStore", - "BaseSyncADKMemoryStore", "MemoryRecord", "SQLSpecMemoryService", "extract_content_text", diff --git a/sqlspec/extensions/adk/memory/store.py b/sqlspec/extensions/adk/memory/store.py index fbe773dfc..f7a7f3091 100644 --- a/sqlspec/extensions/adk/memory/store.py +++ b/sqlspec/extensions/adk/memory/store.py @@ -1,4 +1,4 @@ -"""Base store classes for ADK memory backend (sync and async).""" +"""Base store class for ADK memory backends.""" import logging import re @@ -18,7 +18,7 @@ logger = get_logger("sqlspec.extensions.adk.memory.store") -__all__ = ("BaseAsyncADKMemoryStore", "BaseSyncADKMemoryStore") +__all__ = ("BaseAsyncADKMemoryStore",) COLUMN_NAME_PATTERN: Final = re.compile(r"^(\w+)") @@ -269,231 +269,3 @@ def _get_drop_memory_table_sql(self) -> "list[str]": List of SQL statements to drop the memory table and indexes. """ raise NotImplementedError - - -class BaseSyncADKMemoryStore(ABC, Generic[ConfigT]): - """Base class for sync SQLSpec-backed ADK memory stores. - - Implements storage operations for Google ADK memory entries using - SQLSpec database adapters with synchronous execution. - - This abstract base class provides common functionality for sync database-specific - memory store implementations including: - - Connection management via SQLSpec configs - - Table name validation - - Memory entry CRUD operations - - Text search with optional full-text search support - - Subclasses must implement dialect-specific SQL queries and will be created - in each adapter directory (e.g., sqlspec/adapters/sqlite/adk/store.py). - - Args: - config: SQLSpec database configuration with extension_config["adk"] settings. - - Notes: - Configuration is read from config.extension_config["adk"]: - - memory_table: Memory table name (default: "adk_memory_entries") - - memory_use_fts: Enable full-text search when supported (default: False) - - memory_max_results: Max search results (default: 20) - - owner_id_column: Optional owner FK column DDL (default: None) - - enable_memory: Whether memory is enabled (default: True) - """ - - __slots__ = ( - "_config", - "_enabled", - "_max_results", - "_memory_table", - "_owner_id_column_ddl", - "_owner_id_column_name", - "_use_fts", - ) - - def __init__(self, config: ConfigT) -> None: - """Initialize the sync ADK memory store. - - Args: - config: SQLSpec database configuration. - - Notes: - Reads configuration from config.extension_config["adk"]: - - memory_table: Memory table name (default: "adk_memory_entries") - - memory_use_fts: Enable full-text search when supported (default: False) - - memory_max_results: Max search results (default: 20) - - owner_id_column: Optional owner FK column DDL (default: None) - - enable_memory: Whether memory is enabled (default: True) - """ - self._config = config - store_config = self._get_store_config_from_extension() - self._enabled: bool = store_config.get("enable_memory", True) - self._memory_table: str = str(store_config["memory_table"]) - self._use_fts: bool = bool(store_config.get("use_fts", False)) - self._max_results: int = store_config.get("max_results", 20) - self._owner_id_column_ddl: str | None = store_config.get("owner_id_column") - self._owner_id_column_name: str | None = ( - _parse_owner_id_column(self._owner_id_column_ddl) if self._owner_id_column_ddl else None - ) - validate_identifier(self._memory_table, label="table name") - - def _get_store_config_from_extension(self) -> "_ADKMemoryStoreConfig": - """Extract ADK memory configuration from config.extension_config. - - Returns: - Dict with memory_table, use_fts, max_results, and optionally owner_id_column. - """ - return _get_adk_memory_store_config(self._config) - - @property - def config(self) -> ConfigT: - """Return the database configuration.""" - return self._config - - @property - def memory_table(self) -> str: - """Return the memory table name.""" - return self._memory_table - - @property - def enabled(self) -> bool: - """Return whether memory store is enabled.""" - return self._enabled - - @property - def use_fts(self) -> bool: - """Return whether full-text search is enabled.""" - return self._use_fts - - @property - def max_results(self) -> int: - """Return the max search results limit.""" - return self._max_results - - @property - def owner_id_column_ddl(self) -> "str | None": - """Return the full owner ID column DDL (or None if not configured).""" - return self._owner_id_column_ddl - - @property - def owner_id_column_name(self) -> "str | None": - """Return the owner ID column name only (or None if not configured).""" - return self._owner_id_column_name - - @abstractmethod - def create_tables(self) -> None: - """Create the memory table and indexes if they don't exist. - - Should check self._enabled and skip table creation if False. - """ - raise NotImplementedError - - def ensure_tables(self) -> None: - """Create tables when enabled and emit a standardized log entry.""" - - if not self._enabled: - self._log_memory_table_skipped() - return - self.create_tables() - self._log_memory_table_created() - - @abstractmethod - def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: - """Bulk insert memory entries with deduplication. - - Uses UPSERT pattern to skip duplicates based on event_id. - - Args: - entries: List of memory records to insert. - owner_id: Optional owner ID value for owner_id_column (if configured). - - Returns: - Number of entries actually inserted (excludes duplicates). - - Raises: - RuntimeError: If memory store is disabled. - """ - raise NotImplementedError - - def _log_memory_table_created(self) -> None: - log_with_context( - logger, - logging.DEBUG, - "adk.memory.table.ready", - db_system=resolve_db_system(type(self).__name__), - memory_table=self._memory_table, - ) - - def _log_memory_table_skipped(self) -> None: - log_with_context( - logger, - logging.DEBUG, - "adk.memory.table.skipped", - db_system=resolve_db_system(type(self).__name__), - memory_table=self._memory_table, - reason="disabled", - ) - - @abstractmethod - def search_entries( - self, query: str, app_name: str, user_id: str, limit: "int | None" = None - ) -> "list[MemoryRecord]": - """Search memory entries by text query. - - Uses the configured search strategy (simple ILIKE or FTS). - - Args: - query: Text query to search for. - app_name: Application name to filter by. - user_id: User ID to filter by. - limit: Maximum number of results (defaults to max_results config). - - Returns: - List of matching memory records ordered by relevance/timestamp. - - Raises: - RuntimeError: If memory store is disabled. - """ - raise NotImplementedError - - @abstractmethod - def delete_entries_by_session(self, session_id: str) -> int: - """Delete all memory entries for a specific session. - - Args: - session_id: Session ID to delete entries for. - - Returns: - Number of entries deleted. - """ - raise NotImplementedError - - @abstractmethod - def delete_entries_older_than(self, days: int) -> int: - """Delete memory entries older than specified days. - - Used for TTL cleanup operations. - - Args: - days: Number of days to retain entries. - - Returns: - Number of entries deleted. - """ - raise NotImplementedError - - @abstractmethod - def _get_create_memory_table_sql(self) -> "str | list[str]": - """Get the CREATE TABLE SQL for the memory table. - - Returns: - SQL statement(s) to create the memory table with indexes. - """ - raise NotImplementedError - - @abstractmethod - def _get_drop_memory_table_sql(self) -> "list[str]": - """Get the DROP TABLE SQL statements for this database dialect. - - Returns: - List of SQL statements to drop the memory table and indexes. - """ - raise NotImplementedError diff --git a/sqlspec/extensions/adk/migrations/0001_create_adk_tables.py b/sqlspec/extensions/adk/migrations/0001_create_adk_tables.py index 1dd61a676..a354792cc 100644 --- a/sqlspec/extensions/adk/migrations/0001_create_adk_tables.py +++ b/sqlspec/extensions/adk/migrations/0001_create_adk_tables.py @@ -1,6 +1,5 @@ """Create ADK session, events, and memory tables migration using store DDL definitions.""" -import inspect import logging from typing import TYPE_CHECKING, NoReturn, cast @@ -13,7 +12,7 @@ from sqlspec.utils.logging import get_logger, log_with_context if TYPE_CHECKING: - from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore, BaseSyncADKMemoryStore + from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore from sqlspec.extensions.adk.store import BaseAsyncADKStore from sqlspec.migrations.context import MigrationContext @@ -42,9 +41,7 @@ def _get_store_class(context: "MigrationContext | None") -> "type[BaseAsyncADKSt return cast("type[BaseAsyncADKStore]", _get_adk_adapter_store_class(context.config, "ADKStore")) -def _get_memory_store_class( - context: "MigrationContext | None", -) -> "type[BaseAsyncADKMemoryStore | BaseSyncADKMemoryStore] | None": +def _get_memory_store_class(context: "MigrationContext | None") -> "type[BaseAsyncADKMemoryStore] | None": """Get the appropriate memory store class based on the config's module path. Args: @@ -65,7 +62,7 @@ def _get_memory_store_class( if store_class is None: log_with_context(logger, logging.DEBUG, "adk.migration.memory_store.missing") return None - return cast("type[BaseAsyncADKMemoryStore | BaseSyncADKMemoryStore]", store_class) + return cast("type[BaseAsyncADKMemoryStore]", store_class) def _is_memory_enabled(context: "MigrationContext | None") -> bool: @@ -131,9 +128,7 @@ async def up(context: "MigrationContext | None" = None) -> "list[str]": memory_store_class = _get_memory_store_class(context) if memory_store_class is not None: memory_store = memory_store_class(config=context.config) - memory_sql = memory_store._get_create_memory_table_sql() # pyright: ignore[reportPrivateUsage] - if inspect.isawaitable(memory_sql): - memory_sql = await memory_sql + memory_sql = await memory_store._get_create_memory_table_sql() # pyright: ignore[reportPrivateUsage] if isinstance(memory_sql, list): statements.extend(memory_sql) else: diff --git a/sqlspec/extensions/adk/store.py b/sqlspec/extensions/adk/store.py index 47fe0be64..e8d91f9e8 100644 --- a/sqlspec/extensions/adk/store.py +++ b/sqlspec/extensions/adk/store.py @@ -1,4 +1,4 @@ -"""Base store classes for ADK session backend (sync and async).""" +"""Base store class for ADK session backends.""" import logging import re @@ -20,7 +20,7 @@ logger = get_logger("sqlspec.extensions.adk.store") -__all__ = ("BaseAsyncADKStore", "BaseSyncADKStore") +__all__ = ("BaseAsyncADKStore",) COLUMN_NAME_PATTERN: Final = re.compile(r"^(\w+)") @@ -302,257 +302,3 @@ def _log_tables_created(self) -> None: session_table=self._session_table, events_table=self._events_table, ) - - -class BaseSyncADKStore(ABC, Generic[ConfigT]): - """Base class for sync SQLSpec-backed ADK session stores. - - Implements storage operations for Google ADK sessions and events using - SQLSpec database adapters with synchronous execution. - - This abstract base class provides common functionality for sync database-specific - store implementations including: - - Connection management via SQLSpec configs - - Table name validation - - Session and event CRUD operations - - Subclasses must implement dialect-specific SQL queries and will be created - in each adapter directory (e.g., sqlspec/adapters/sqlite/adk/store.py). - - Args: - config: SQLSpec database configuration with extension_config["adk"] settings. - - Notes: - Configuration is read from config.extension_config["adk"]: - - session_table: Sessions table name (default: "adk_sessions") - - events_table: Events table name (default: "adk_events") - - owner_id_column: Optional owner FK column DDL (default: None) - """ - - __slots__ = ("_config", "_events_table", "_owner_id_column_ddl", "_owner_id_column_name", "_session_table") - - def __init__(self, config: ConfigT) -> None: - """Initialize the sync ADK store. - - Args: - config: SQLSpec database configuration. - - Notes: - Reads configuration from config.extension_config["adk"]: - - session_table: Sessions table name (default: "adk_sessions") - - events_table: Events table name (default: "adk_events") - - owner_id_column: Optional owner FK column DDL (default: None) - """ - self._config = config - store_config = self._get_store_config_from_extension() - self._session_table: str = str(store_config["session_table"]) - self._events_table: str = str(store_config["events_table"]) - self._owner_id_column_ddl: str | None = store_config.get("owner_id_column") - self._owner_id_column_name: str | None = ( - _parse_owner_id_column(self._owner_id_column_ddl) if self._owner_id_column_ddl else None - ) - validate_identifier(self._session_table, label="table name") - validate_identifier(self._events_table, label="table name") - - def _get_store_config_from_extension(self) -> "dict[str, Any]": - """Extract ADK store configuration from config.extension_config. - - Returns: - Dict with session_table, events_table, and optionally owner_id_column. - """ - return dict(_get_adk_session_store_config(self._config)) - - @property - def config(self) -> ConfigT: - """Return the database configuration.""" - return self._config - - @property - def session_table(self) -> str: - """Return the sessions table name.""" - return self._session_table - - @property - def events_table(self) -> str: - """Return the events table name.""" - return self._events_table - - @property - def owner_id_column_ddl(self) -> "str | None": - """Return the full owner ID column DDL (or None if not configured).""" - return self._owner_id_column_ddl - - @property - def owner_id_column_name(self) -> "str | None": - """Return the owner ID column name only (or None if not configured).""" - return self._owner_id_column_name - - @abstractmethod - def create_session( - self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None - ) -> "SessionRecord": - """Create a new session. - - Args: - session_id: Unique identifier for the session. - app_name: Name of the application. - user_id: ID of the user. - state: Session state dictionary. - owner_id: Optional owner ID value for owner_id_column (if configured). - - Returns: - The created session record. - """ - raise NotImplementedError - - @abstractmethod - def get_session(self, session_id: str) -> "SessionRecord | None": - """Get a session by ID. - - Args: - session_id: Session identifier. - - Returns: - Session record if found, None otherwise. - """ - raise NotImplementedError - - @abstractmethod - def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state. - - Args: - session_id: Session identifier. - state: New state dictionary. - """ - raise NotImplementedError - - @abstractmethod - def list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[SessionRecord]": - """List all sessions for an app, optionally filtered by user. - - Args: - app_name: Name of the application. - user_id: ID of the user. If None, returns all sessions for the app. - - Returns: - List of session records. - """ - raise NotImplementedError - - @abstractmethod - def delete_session(self, session_id: str) -> None: - """Delete a session and its events. - - Args: - session_id: Session identifier. - """ - raise NotImplementedError - - @abstractmethod - def create_event( - self, - event_id: str, - session_id: str, - app_name: str, - user_id: str, - author: "str | None" = None, - actions: "bytes | None" = None, - content: "dict[str, Any] | None" = None, - **kwargs: Any, - ) -> "EventRecord": - """Create a new event. - - Args: - event_id: Unique event identifier. - session_id: Session identifier. - app_name: Application name. - user_id: User identifier. - author: Event author (user/assistant/system). - actions: Pickled actions object. - content: Event content (JSONB/JSON). - **kwargs: Additional optional fields. - - Returns: - Created event record. - """ - raise NotImplementedError - - @abstractmethod - def create_event_and_update_state( - self, event_record: "EventRecord", session_id: str, state: "dict[str, Any]" - ) -> None: - """Atomically create an event and update the session's durable state. - - This is the authoritative durable write boundary for post-creation - session mutations. The event insert and state update must succeed - together or fail together. - - Args: - event_record: Event record to store. - session_id: Session identifier whose state should be updated. - state: Post-append durable state snapshot (``temp:`` keys already - stripped by the service layer). - """ - raise NotImplementedError - - @abstractmethod - def list_events(self, session_id: str) -> "list[EventRecord]": - """List events for a session ordered by timestamp. - - Args: - session_id: Session identifier. - - Returns: - List of event records ordered by timestamp ASC. - """ - raise NotImplementedError - - @abstractmethod - def create_tables(self) -> None: - """Create both sessions and events tables if they don't exist.""" - raise NotImplementedError - - def ensure_tables(self) -> None: - """Create tables and emit a standardized log entry.""" - - self.create_tables() - self._log_tables_created() - - @abstractmethod - def _get_create_sessions_table_sql(self) -> str: - """Get SQL to create sessions table. - - Returns: - SQL statement to create adk_sessions table with indexes. - """ - raise NotImplementedError - - @abstractmethod - def _get_create_events_table_sql(self) -> str: - """Get SQL to create events table. - - Returns: - SQL statement to create adk_events table with indexes. - """ - raise NotImplementedError - - def _log_tables_created(self) -> None: - log_with_context( - logger, - logging.DEBUG, - "adk.tables.ready", - db_system=resolve_db_system(type(self).__name__), - session_table=self._session_table, - events_table=self._events_table, - ) - - @abstractmethod - def _get_drop_tables_sql(self) -> "list[str]": - """Get SQL to drop tables. - - Returns: - List of SQL statements to drop tables and indexes. - Order matters: drop events before sessions due to FK. - """ - raise NotImplementedError diff --git a/tests/unit/extensions/test_adk/test_store_config.py b/tests/unit/extensions/test_adk/test_store_config.py index 414f5d980..fd35d2c3e 100644 --- a/tests/unit/extensions/test_adk/test_store_config.py +++ b/tests/unit/extensions/test_adk/test_store_config.py @@ -7,13 +7,15 @@ import pytest +import sqlspec.extensions.adk as adk_module from sqlspec.extensions.adk import EventRecord, SessionRecord +from sqlspec.extensions.adk import store as adk_store_module from sqlspec.extensions.adk.artifact._types import ArtifactRecord from sqlspec.extensions.adk.artifact.store import BaseSyncADKArtifactStore from sqlspec.extensions.adk.memory import MemoryRecord from sqlspec.extensions.adk.memory import store as memory_store_module -from sqlspec.extensions.adk.memory.store import BaseSyncADKMemoryStore -from sqlspec.extensions.adk.store import BaseSyncADKStore +from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore +from sqlspec.extensions.adk.store import BaseAsyncADKStore class _Config: @@ -29,8 +31,8 @@ def provide_session(self) -> str: return "original-session" -class _SyncSessionStore(BaseSyncADKStore[Any]): - def create_session( +class _AsyncSessionStore(BaseAsyncADKStore[Any]): + async def create_session( self, session_id: str, app_name: str, user_id: str, state: dict[str, Any], owner_id: Any | None = None ) -> SessionRecord: return SessionRecord( @@ -42,77 +44,74 @@ def create_session( update_time=datetime.now(), ) - def get_session(self, session_id: str) -> SessionRecord | None: + async def get_session(self, session_id: str) -> SessionRecord | None: return None - def update_session_state(self, session_id: str, state: dict[str, Any]) -> None: + async def update_session_state(self, session_id: str, state: dict[str, Any]) -> None: return None - def list_sessions(self, app_name: str, user_id: str | None = None) -> list[SessionRecord]: + async def list_sessions(self, app_name: str, user_id: str | None = None) -> list[SessionRecord]: return [] - def delete_session(self, session_id: str) -> None: + async def delete_session(self, session_id: str) -> None: return None - def create_event( - self, - event_id: str, - session_id: str, - app_name: str, - user_id: str, - author: str | None = None, - actions: bytes | None = None, - content: dict[str, Any] | None = None, - **kwargs: Any, - ) -> EventRecord: - return EventRecord( - session_id=session_id, - invocation_id=event_id, - author=author or user_id, - timestamp=datetime.now(), - event_data=content or {}, - ) - - def create_event_and_update_state(self, event_record: EventRecord, session_id: str, state: dict[str, Any]) -> None: + async def append_event(self, event_record: EventRecord) -> None: return None - def list_events(self, session_id: str) -> list[EventRecord]: + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: dict[str, Any] + ) -> SessionRecord: + return SessionRecord( + id=session_id, + app_name="test-app", + user_id="test-user", + state=state, + create_time=datetime.now(), + update_time=datetime.now(), + ) + + async def get_events( + self, session_id: str, after_timestamp: datetime | None = None, limit: int | None = None + ) -> list[EventRecord]: return [] - def create_tables(self) -> None: + async def create_tables(self) -> None: return None - def _get_create_sessions_table_sql(self) -> str: + async def _get_create_sessions_table_sql(self) -> str: return "" - def _get_create_events_table_sql(self) -> str: + async def _get_create_events_table_sql(self) -> str: return "" def _get_drop_tables_sql(self) -> list[str]: return [] -class _SyncMemoryStore(BaseSyncADKMemoryStore[Any]): +class _AsyncMemoryStore(BaseAsyncADKMemoryStore[Any]): def __init__(self, config: _Config) -> None: super().__init__(config) self.create_tables_called = False - def create_tables(self) -> None: + async def create_tables(self) -> None: self.create_tables_called = True - def insert_memory_entries(self, entries: list[MemoryRecord], owner_id: object | None = None) -> int: + async def insert_memory_entries(self, entries: list[MemoryRecord], owner_id: object | None = None) -> int: return len(entries) - def search_entries(self, query: str, app_name: str, user_id: str, limit: int | None = None) -> list[MemoryRecord]: + async def search_entries( + self, query: str, app_name: str, user_id: str, limit: int | None = None + ) -> list[MemoryRecord]: return [] - def delete_entries_by_session(self, session_id: str) -> int: + async def delete_entries_by_session(self, session_id: str) -> int: return 0 - def delete_entries_older_than(self, days: int) -> int: + async def delete_entries_older_than(self, days: int) -> int: return 0 - def _get_create_memory_table_sql(self) -> str | list[str]: + async def _get_create_memory_table_sql(self) -> str | list[str]: return "" def _get_drop_memory_table_sql(self) -> list[str]: @@ -148,7 +147,7 @@ def create_table(self) -> None: return None -@pytest.mark.parametrize("store_cls", [_SyncSessionStore, _SyncMemoryStore, _SyncArtifactStore]) +@pytest.mark.parametrize("store_cls", [_AsyncSessionStore, _AsyncMemoryStore, _SyncArtifactStore]) def test_adk_base_stores_keep_original_config(store_cls: type[Any]) -> None: config = _Config() store = store_cls(config) @@ -156,16 +155,30 @@ def test_adk_base_stores_keep_original_config(store_cls: type[Any]) -> None: assert store.config is config -def test_sync_memory_store_logs_ready_with_log_with_context(monkeypatch: pytest.MonkeyPatch) -> None: +def test_session_store_contract_exports_async_surface_only() -> None: + assert "BaseSyncADKStore" not in adk_module.__all__ + assert "BaseSyncADKStore" not in adk_store_module.__all__ + assert not hasattr(adk_module, "BaseSyncADKStore") + assert not hasattr(adk_store_module, "BaseSyncADKStore") + + +def test_memory_store_contract_exports_async_surface_only() -> None: + assert "BaseSyncADKMemoryStore" not in adk_module.__all__ + assert "BaseSyncADKMemoryStore" not in memory_store_module.__all__ + assert not hasattr(adk_module, "BaseSyncADKMemoryStore") + assert not hasattr(memory_store_module, "BaseSyncADKMemoryStore") + + +async def test_async_memory_store_logs_ready_with_log_with_context(monkeypatch: pytest.MonkeyPatch) -> None: calls: list[dict[str, Any]] = [] def fake_log_with_context(logger: Any, level: int, event: str, **context: Any) -> None: calls.append({"level": level, "event": event, "context": context}) monkeypatch.setattr(memory_store_module, "log_with_context", fake_log_with_context) - store = _SyncMemoryStore(_Config({"memory_table": "test_memories"})) + store = _AsyncMemoryStore(_Config({"memory_table": "test_memories"})) - store.ensure_tables() + await store.ensure_tables() assert store.create_tables_called assert len(calls) == 1 @@ -175,16 +188,16 @@ def fake_log_with_context(logger: Any, level: int, event: str, **context: Any) - assert "db_system" in calls[0]["context"] -def test_sync_memory_store_logs_disabled_with_log_with_context(monkeypatch: pytest.MonkeyPatch) -> None: +async def test_async_memory_store_logs_disabled_with_log_with_context(monkeypatch: pytest.MonkeyPatch) -> None: calls: list[dict[str, Any]] = [] def fake_log_with_context(logger: Any, level: int, event: str, **context: Any) -> None: calls.append({"level": level, "event": event, "context": context}) monkeypatch.setattr(memory_store_module, "log_with_context", fake_log_with_context) - store = _SyncMemoryStore(_Config({"enable_memory": False, "memory_table": "test_memories"})) + store = _AsyncMemoryStore(_Config({"enable_memory": False, "memory_table": "test_memories"})) - store.ensure_tables() + await store.ensure_tables() assert not store.create_tables_called assert len(calls) == 1 diff --git a/tests/unit/extensions/test_adk/test_store_instantiation.py b/tests/unit/extensions/test_adk/test_store_instantiation.py index 21a79a4fa..2bfbd9aa5 100644 --- a/tests/unit/extensions/test_adk/test_store_instantiation.py +++ b/tests/unit/extensions/test_adk/test_store_instantiation.py @@ -16,8 +16,8 @@ import pytest from sqlspec.exceptions import SQLSpecError -from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore, BaseSyncADKMemoryStore -from sqlspec.extensions.adk.store import BaseAsyncADKStore, BaseSyncADKStore +from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore +from sqlspec.extensions.adk.store import BaseAsyncADKStore SESSION_STORE_CLASSES = [ "sqlspec.adapters.asyncpg.adk.AsyncpgADKStore", @@ -96,12 +96,11 @@ def test_store_method_signatures_match_base_contract(class_path: str) -> None: if issubclass(cls, BaseAsyncADKStore): base: type = BaseAsyncADKStore - elif issubclass(cls, BaseSyncADKStore): - base = BaseSyncADKStore elif issubclass(cls, BaseAsyncADKMemoryStore): base = BaseAsyncADKMemoryStore else: - base = BaseSyncADKMemoryStore + msg = f"{class_path} must inherit from the async ADK store contract" + raise AssertionError(msg) for method_name in base.__abstractmethods__: base_signature = inspect.signature(getattr(base, method_name)) From 5f0f8a89237ec33eca60d66cd88c906e26e7a1ce Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 23 May 2026 19:21:38 +0000 Subject: [PATCH 06/29] feat(adk): add shared store contract tests --- sqlspec/adapters/duckdb/adk/store.py | 21 +- .../adapters/_adk_contract_helpers.py | 277 ++++++++++++++++++ .../extensions/adk/test_memory_store.py | 14 + .../aiosqlite/extensions/adk/test_store.py | 10 + .../extensions/adk/test_memory_store.py | 7 + .../duckdb/extensions/adk/test_store.py | 6 + .../extensions/adk/test_memory_store.py | 14 + .../sqlite/extensions/adk/test_store.py | 10 + 8 files changed, 349 insertions(+), 10 deletions(-) create mode 100644 tests/integration/adapters/_adk_contract_helpers.py diff --git a/sqlspec/adapters/duckdb/adk/store.py b/sqlspec/adapters/duckdb/adk/store.py index fb689758e..995cd0e99 100644 --- a/sqlspec/adapters/duckdb/adk/store.py +++ b/sqlspec/adapters/duckdb/adk/store.py @@ -462,18 +462,19 @@ def _append_event_and_update_state( """ with self._config.provide_connection() as conn: - conn.execute( - insert_sql, - ( - event_record["session_id"], - event_record["invocation_id"], - event_record["author"], - event_record["timestamp"], - event_data_str, - ), - ) cursor = conn.execute(update_sql, (state_json, now, session_id)) row = cursor.fetchone() + if row is not None: + conn.execute( + insert_sql, + ( + event_record["session_id"], + event_record["invocation_id"], + event_record["author"], + event_record["timestamp"], + event_data_str, + ), + ) conn.commit() if row is None: diff --git a/tests/integration/adapters/_adk_contract_helpers.py b/tests/integration/adapters/_adk_contract_helpers.py new file mode 100644 index 000000000..6fd1e37a9 --- /dev/null +++ b/tests/integration/adapters/_adk_contract_helpers.py @@ -0,0 +1,277 @@ +"""Shared acceptance helpers for ADK adapter integration tests.""" + +from datetime import datetime, timedelta, timezone +from typing import Protocol +from uuid import uuid4 + +from sqlspec.extensions.adk import EventRecord, MemoryRecord, SessionRecord + +__all__ = ("assert_memory_store_contract", "assert_session_event_store_contract") + + +class SessionEventStore(Protocol): + """Minimal ADK session/event store surface used by contract tests.""" + + async def create_session( + self, session_id: str, app_name: str, user_id: str, state: dict[str, object], owner_id: object | None = None + ) -> SessionRecord: ... + + async def get_session(self, session_id: str) -> SessionRecord | None: ... + + async def update_session_state(self, session_id: str, state: dict[str, object]) -> None: ... + + async def list_sessions(self, app_name: str, user_id: str | None = None) -> list[SessionRecord]: ... + + async def delete_session(self, session_id: str) -> None: ... + + async def append_event(self, event_record: EventRecord) -> None: ... + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: dict[str, object] + ) -> SessionRecord: ... + + async def get_events( + self, session_id: str, after_timestamp: datetime | None = None, limit: int | None = None + ) -> list[EventRecord]: ... + + +class MemoryStore(Protocol): + """Minimal ADK memory store surface used by contract tests.""" + + async def insert_memory_entries(self, entries: list[MemoryRecord], owner_id: object | None = None) -> int: ... + + async def search_entries( + self, query: str, app_name: str, user_id: str, limit: int | None = None + ) -> list[MemoryRecord]: ... + + async def delete_entries_by_session(self, session_id: str) -> int: ... + + async def delete_entries_older_than(self, days: int) -> int: ... + + +def _contract_key(marker: str, suffix: str) -> str: + return f"adk-contract-{marker}-{suffix}-{uuid4().hex}" + + +def _event_record( + *, + session_id: str, + event_id: str, + invocation_id: str, + author: str, + timestamp: datetime, + event_data: dict[str, object], +) -> EventRecord: + data = dict(event_data) + data.setdefault("id", event_id) + return { + "session_id": session_id, + "invocation_id": invocation_id, + "author": author, + "timestamp": timestamp, + "event_data": data, + } + + +def _memory_record( + *, + marker: str, + session_id: str, + app_name: str, + user_id: str, + event_id: str, + content_text: str, + inserted_at: datetime, + metadata_json: dict[str, object] | None = None, +) -> MemoryRecord: + return { + "id": _contract_key(marker, "memory"), + "session_id": session_id, + "app_name": app_name, + "user_id": user_id, + "event_id": event_id, + "author": "user", + "timestamp": inserted_at, + "content_json": {"parts": [{"text": content_text}]}, + "content_text": content_text, + "metadata_json": metadata_json, + "inserted_at": inserted_at, + } + + +def _event_data(record: EventRecord) -> dict[str, object]: + value = record["event_data"] + assert isinstance(value, dict) + return value + + +async def assert_session_event_store_contract(store: SessionEventStore, *, marker: str) -> None: + """Assert the shared ADK session/event store acceptance contract. + + Backend-specific integration tests call this helper after creating tables. + It keeps core session/event expectations identical across async and + sync-driver-backed stores. + """ + app_name = _contract_key(marker, "app") + user_id = _contract_key(marker, "user") + session_id = _contract_key(marker, "session") + base_time = datetime(2026, 5, 23, 12, 0, tzinfo=timezone.utc) + + created = await store.create_session(session_id, app_name, user_id, {"created": True}) + assert created["id"] == session_id + assert created["app_name"] == app_name + assert created["user_id"] == user_id + assert created["state"] == {"created": True} + + first_event = _event_record( + session_id=session_id, + event_id="contract-event-1", + invocation_id="contract-inv-1", + author="user", + timestamp=base_time, + event_data={ + "output": {"kind": "text", "value": "captured by full-event JSON"}, + "node_info": {"node_path": ["root", "agent"]}, + "actions": { + "route": "next", + "request_task": {"id": "task-1"}, + "finish_task": {"id": "task-1"}, + "state_delta": {"turn": 1}, + }, + }, + ) + updated = await store.append_event_and_update_state(first_event, session_id, {"turn": 1}) + assert updated["id"] == session_id + assert updated["state"] == {"turn": 1} + + fetched = await store.get_session(session_id) + assert fetched is not None + assert fetched["state"] == {"turn": 1} + + stored_events = await store.get_events(session_id) + assert len(stored_events) == 1 + assert stored_events[0]["invocation_id"] == "contract-inv-1" + first_data = _event_data(stored_events[0]) + assert first_data["output"] == {"kind": "text", "value": "captured by full-event JSON"} + assert first_data["node_info"] == {"node_path": ["root", "agent"]} + assert first_data["actions"] == { + "route": "next", + "request_task": {"id": "task-1"}, + "finish_task": {"id": "task-1"}, + "state_delta": {"turn": 1}, + } + + await store.append_event( + _event_record( + session_id=session_id, + event_id="contract-event-2", + invocation_id="contract-inv-2", + author="model", + timestamp=base_time + timedelta(seconds=1), + event_data={"content": {"parts": [{"text": "second"}]}}, + ) + ) + await store.append_event( + _event_record( + session_id=session_id, + event_id="contract-event-3", + invocation_id="contract-inv-3", + author="model", + timestamp=base_time + timedelta(seconds=2), + event_data={"content": {"parts": [{"text": "third"}]}}, + ) + ) + + filtered = await store.get_events(session_id, after_timestamp=base_time + timedelta(milliseconds=500), limit=1) + assert [event["invocation_id"] for event in filtered] == ["contract-inv-2"] + + listed = await store.list_sessions(app_name, user_id) + assert any(record["id"] == session_id for record in listed) + + await store.delete_session(session_id) + assert await store.get_session(session_id) is None + assert await store.get_events(session_id) == [] + + +async def assert_memory_store_contract(store: MemoryStore, *, marker: str) -> None: + """Assert the shared ADK memory store acceptance contract.""" + app_name = _contract_key(marker, "app") + user_id = _contract_key(marker, "user") + other_user_id = _contract_key(marker, "other-user") + session_id = _contract_key(marker, "session") + other_session_id = _contract_key(marker, "other-session") + now = datetime.now(timezone.utc) + + espresso = _memory_record( + marker=marker, + session_id=session_id, + app_name=app_name, + user_id=user_id, + event_id=_contract_key(marker, "event-espresso"), + content_text="espresso roast contract memory", + inserted_at=now, + metadata_json={"source": "contract", "priority": 2}, + ) + latte = _memory_record( + marker=marker, + session_id=session_id, + app_name=app_name, + user_id=user_id, + event_id=_contract_key(marker, "event-latte"), + content_text="latte foam contract memory", + inserted_at=now, + ) + other_user = _memory_record( + marker=marker, + session_id=other_session_id, + app_name=app_name, + user_id=other_user_id, + event_id=_contract_key(marker, "event-other"), + content_text="espresso roast contract memory", + inserted_at=now, + ) + + inserted = await store.insert_memory_entries([espresso, latte, other_user]) + assert inserted == 3 + + duplicate_count = await store.insert_memory_entries([espresso]) + assert duplicate_count == 0 + + results = await store.search_entries("espresso", app_name, user_id, limit=10) + assert len(results) == 1 + assert results[0]["event_id"] == espresso["event_id"] + assert results[0]["metadata_json"] == {"source": "contract", "priority": 2} + + other_results = await store.search_entries("espresso", app_name, other_user_id, limit=10) + assert len(other_results) == 1 + assert other_results[0]["event_id"] == other_user["event_id"] + + deleted_session = await store.delete_entries_by_session(session_id) + assert deleted_session == 2 + assert await store.search_entries("latte", app_name, user_id, limit=10) == [] + + old_record = _memory_record( + marker=marker, + session_id=_contract_key(marker, "old-session"), + app_name=app_name, + user_id=user_id, + event_id=_contract_key(marker, "event-old"), + content_text="old contract memory", + inserted_at=now - timedelta(days=40), + ) + fresh_record = _memory_record( + marker=marker, + session_id=_contract_key(marker, "fresh-session"), + app_name=app_name, + user_id=user_id, + event_id=_contract_key(marker, "event-fresh"), + content_text="fresh contract memory", + inserted_at=now, + ) + assert await store.insert_memory_entries([old_record, fresh_record]) == 2 + + deleted_old = await store.delete_entries_older_than(30) + assert deleted_old == 1 + fresh_results = await store.search_entries("fresh", app_name, user_id, limit=10) + assert len(fresh_results) == 1 + assert fresh_results[0]["event_id"] == fresh_record["event_id"] diff --git a/tests/integration/adapters/aiosqlite/extensions/adk/test_memory_store.py b/tests/integration/adapters/aiosqlite/extensions/adk/test_memory_store.py index 3ce5f04d9..0bc65b2e0 100644 --- a/tests/integration/adapters/aiosqlite/extensions/adk/test_memory_store.py +++ b/tests/integration/adapters/aiosqlite/extensions/adk/test_memory_store.py @@ -9,6 +9,7 @@ from sqlspec.adapters.aiosqlite import AiosqliteConfig from sqlspec.adapters.aiosqlite.adk import AiosqliteADKMemoryStore from sqlspec.extensions.adk import MemoryRecord +from tests.integration.adapters._adk_contract_helpers import assert_memory_store_contract pytestmark = pytest.mark.xdist_group("sqlite") @@ -54,6 +55,19 @@ async def test_aiosqlite_memory_store_insert_search_dedup() -> None: await config.close_pool() +async def test_aiosqlite_memory_store_shared_contract() -> None: + """AioSQLite satisfies the shared ADK memory store acceptance contract.""" + with tempfile.NamedTemporaryFile(suffix=".db") as tmp: + config = AiosqliteConfig(connection_config={"database": tmp.name}) + store = AiosqliteADKMemoryStore(config) + await store.create_tables() + + try: + await assert_memory_store_contract(store, marker="aiosqlite") + finally: + await config.close_pool() + + async def test_aiosqlite_memory_store_fts_search() -> None: """FTS-enabled memory stores search through the FTS5 virtual table.""" with tempfile.NamedTemporaryFile(suffix=".db") as tmp: diff --git a/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py b/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py index 35d169f96..ba9bdb174 100644 --- a/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py +++ b/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py @@ -8,6 +8,7 @@ from sqlspec.adapters.aiosqlite import AiosqliteConfig from sqlspec.adapters.aiosqlite.adk import AiosqliteADKStore from sqlspec.extensions.adk import EventRecord +from tests.integration.adapters._adk_contract_helpers import assert_session_event_store_contract pytestmark = pytest.mark.xdist_group("sqlite") @@ -54,6 +55,15 @@ async def test_aiosqlite_session_empty_state_round_trip(tmp_path: Path) -> None: await config.close_pool() +async def test_aiosqlite_session_event_store_shared_contract(tmp_path: Path) -> None: + """AioSQLite satisfies the shared ADK session/event store acceptance contract.""" + config, store = await _build_store(tmp_path) + try: + await assert_session_event_store_contract(store, marker="aiosqlite") + finally: + await config.close_pool() + + async def test_aiosqlite_append_event_and_update_state_is_atomic_contract(tmp_path: Path) -> None: """Event append and durable state update happen through the clean-break method.""" config, store = await _build_store(tmp_path) diff --git a/tests/integration/adapters/duckdb/extensions/adk/test_memory_store.py b/tests/integration/adapters/duckdb/extensions/adk/test_memory_store.py index b5a2b6ca5..b626efc88 100644 --- a/tests/integration/adapters/duckdb/extensions/adk/test_memory_store.py +++ b/tests/integration/adapters/duckdb/extensions/adk/test_memory_store.py @@ -10,6 +10,7 @@ from sqlspec.adapters.duckdb.adk import DuckdbADKMemoryStore from sqlspec.adapters.duckdb.config import DuckDBConfig from sqlspec.extensions.adk import MemoryRecord +from tests.integration.adapters._adk_contract_helpers import assert_memory_store_contract pytestmark = [pytest.mark.duckdb, pytest.mark.integration] @@ -71,6 +72,12 @@ async def test_duckdb_memory_store_insert_search_dedup(tmp_path: Path) -> None: assert deduped == 0 +async def test_duckdb_memory_store_shared_contract(tmp_path: Path) -> None: + """DuckDB satisfies the shared ADK memory store acceptance contract.""" + store = await _build_store(tmp_path) + await assert_memory_store_contract(store, marker="duckdb") + + async def test_duckdb_memory_store_delete_by_session(tmp_path: Path) -> None: """Delete memory entries by session id.""" store = await _build_store(tmp_path) diff --git a/tests/integration/adapters/duckdb/extensions/adk/test_store.py b/tests/integration/adapters/duckdb/extensions/adk/test_store.py index 0fb48e0ae..423e23a16 100644 --- a/tests/integration/adapters/duckdb/extensions/adk/test_store.py +++ b/tests/integration/adapters/duckdb/extensions/adk/test_store.py @@ -10,6 +10,7 @@ from sqlspec.adapters.duckdb.adk import DuckdbADKStore from sqlspec.adapters.duckdb.config import DuckDBConfig from sqlspec.extensions.adk import EventRecord +from tests.integration.adapters._adk_contract_helpers import assert_session_event_store_contract pytestmark = [pytest.mark.duckdb, pytest.mark.integration] @@ -47,6 +48,11 @@ async def test_create_tables(duckdb_adk_store: DuckdbADKStore) -> None: assert duckdb_adk_store.events_table == "test_events" +async def test_duckdb_session_event_store_shared_contract(duckdb_adk_store: DuckdbADKStore) -> None: + """DuckDB satisfies the shared ADK session/event store acceptance contract.""" + await assert_session_event_store_contract(duckdb_adk_store, marker="duckdb") + + async def test_create_and_get_session(duckdb_adk_store: DuckdbADKStore) -> None: """Test creating and retrieving a session.""" session_id = "session-001" diff --git a/tests/integration/adapters/sqlite/extensions/adk/test_memory_store.py b/tests/integration/adapters/sqlite/extensions/adk/test_memory_store.py index 492c4a18c..19dc7fc96 100644 --- a/tests/integration/adapters/sqlite/extensions/adk/test_memory_store.py +++ b/tests/integration/adapters/sqlite/extensions/adk/test_memory_store.py @@ -9,6 +9,7 @@ from sqlspec.adapters.sqlite import SqliteConfig from sqlspec.adapters.sqlite.adk import SqliteADKMemoryStore from sqlspec.extensions.adk import MemoryRecord +from tests.integration.adapters._adk_contract_helpers import assert_memory_store_contract pytestmark = pytest.mark.xdist_group("sqlite") @@ -52,6 +53,19 @@ async def test_sqlite_memory_store_insert_search_dedup() -> None: assert deduped == 0 +async def test_sqlite_memory_store_shared_contract() -> None: + """SQLite satisfies the shared ADK memory store acceptance contract.""" + with tempfile.NamedTemporaryFile(suffix=".db") as tmp: + config = SqliteConfig(connection_config={"database": tmp.name}) + store = SqliteADKMemoryStore(config) + await store.create_tables() + + try: + await assert_memory_store_contract(store, marker="sqlite") + finally: + config.close_pool() + + async def test_sqlite_memory_store_fts_search() -> None: """FTS-enabled memory stores search through the FTS5 virtual table.""" with tempfile.NamedTemporaryFile(suffix=".db") as tmp: diff --git a/tests/integration/adapters/sqlite/extensions/adk/test_store.py b/tests/integration/adapters/sqlite/extensions/adk/test_store.py index 3e31fab35..7a22a3b3f 100644 --- a/tests/integration/adapters/sqlite/extensions/adk/test_store.py +++ b/tests/integration/adapters/sqlite/extensions/adk/test_store.py @@ -8,6 +8,7 @@ from sqlspec.adapters.sqlite import SqliteConfig from sqlspec.adapters.sqlite.adk import SqliteADKStore from sqlspec.extensions.adk import EventRecord +from tests.integration.adapters._adk_contract_helpers import assert_session_event_store_contract pytestmark = pytest.mark.xdist_group("sqlite") @@ -34,6 +35,15 @@ async def test_sqlite_session_empty_state_round_trip(tmp_path: Path) -> None: config.close_pool() +async def test_sqlite_session_event_store_shared_contract(tmp_path: Path) -> None: + """SQLite satisfies the shared ADK session/event store acceptance contract.""" + config, store = await _build_store(tmp_path) + try: + await assert_session_event_store_contract(store, marker="sqlite") + finally: + config.close_pool() + + async def test_sqlite_append_event_and_update_state_is_atomic_contract(tmp_path: Path) -> None: """Event append and durable state update happen through the clean-break method.""" config, store = await _build_store(tmp_path) From 02152dd3ce0b8387bf66286458a5c1e78c5a56d5 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 23 May 2026 19:34:02 +0000 Subject: [PATCH 07/29] feat(adk): define nested config model --- sqlspec/config.py | 127 +++++++++++++++++- sqlspec/extensions/adk/_config_utils.py | 85 ++++++++++-- .../test_adk/test_config_resolution.py | 79 +++++++++++ 3 files changed, 274 insertions(+), 17 deletions(-) create mode 100644 tests/unit/extensions/test_adk/test_config_resolution.py diff --git a/sqlspec/config.py b/sqlspec/config.py index 3e201b3d2..775522c8b 100644 --- a/sqlspec/config.py +++ b/sqlspec/config.py @@ -45,11 +45,7 @@ __all__ = ( - "ADKCompressionConfig", "ADKConfig", - "ADKPartitionConfig", - "ADKRetentionConfig", - "ADKSqliteOptimizationConfig", "AsyncConfigT", "AsyncDatabaseConfig", "ConfigT", @@ -683,6 +679,102 @@ class ADKSqliteOptimizationConfig(TypedDict): """ +ADKOptimizationMode: TypeAlias = Literal["auto", "enable", "disable"] +"""Tri-state optimization control used by ADK capability negotiation.""" + + +class ADKSchemaConfig(TypedDict): + """Shared ADK schema naming and migration controls.""" + + session_table: NotRequired[str] + events_table: NotRequired[str] + memory_table: NotRequired[str] + artifact_table: NotRequired[str] + app_state_table: NotRequired[str] + user_state_table: NotRequired[str] + metadata_table: NotRequired[str] + owner_id_column: NotRequired[str] + schema_version: NotRequired[int] + include_sessions_migration: NotRequired[bool] + include_memory_migration: NotRequired[bool] + include_artifact_migration: NotRequired[bool] + + +class ADKSearchConfig(TypedDict): + """Shared ADK search configuration.""" + + strategy: NotRequired[Literal["auto", "like", "fts", "vector"]] + use_fts: NotRequired[bool] + language: NotRequired[str] + max_results: NotRequired[int] + + +class ADKMemoryConfig(TypedDict): + """Shared ADK memory configuration.""" + + enabled: NotRequired[bool] + table: NotRequired[str] + max_results: NotRequired[int] + search: NotRequired[ADKSearchConfig] + + +class ADKArtifactConfig(TypedDict): + """Shared ADK artifact configuration.""" + + table: NotRequired[str] + storage_uri: NotRequired[str] + + +class ADKOptimizationConfig(TypedDict): + """Shared ADK data-model optimization controls.""" + + generated_columns: NotRequired[ADKOptimizationMode] + null_encoded_empty_state: NotRequired[ADKOptimizationMode] + skip_noop_session_update: NotRequired[ADKOptimizationMode] + append_only_event_partitioning: NotRequired[ADKOptimizationMode] + covering_indexes: NotRequired[ADKOptimizationMode] + duckdb_struct_events: NotRequired[ADKOptimizationMode] + spanner_commit_timestamp_pk_suffix: NotRequired[ADKOptimizationMode] + alloydb_columnar_autopromote: NotRequired[ADKOptimizationMode] + + +class ADKOracleConfig(TypedDict): + """Oracle-specific ADK capability settings.""" + + in_memory: NotRequired[bool] + session_table_options: NotRequired[str] + events_table_options: NotRequired[str] + memory_table_options: NotRequired[str] + compression: NotRequired[ADKCompressionConfig] + partitioning: NotRequired[ADKPartitionConfig] + + +class ADKSpannerConfig(TypedDict): + """Spanner-specific ADK capability settings.""" + + shard_count: NotRequired[int] + interleave_events_in_sessions: NotRequired[bool] + session_table_options: NotRequired[str] + events_table_options: NotRequired[str] + memory_table_options: NotRequired[str] + + +class ADKADBCConfig(TypedDict): + """ADBC-specific ADK capability settings.""" + + dialect: NotRequired[str] + table_options: NotRequired[str] + + +class ADKBigQueryConfig(TypedDict): + """BigQuery-specific ADK capability settings.""" + + dataset: NotRequired[str] + partition_expiration_days: NotRequired[int] + clustering_fields: NotRequired[tuple[str, ...] | list[str]] + table_options: NotRequired[str] + + class ADKConfig(TypedDict): """Configuration options for ADK session and memory store extension. @@ -714,6 +806,33 @@ class ADKConfig(TypedDict): You can use plain dicts as well. """ + schema: NotRequired[ADKSchemaConfig] + """Shared schema naming, owner binding, and migration controls.""" + + memory: NotRequired[ADKMemoryConfig] + """Shared memory enablement and result-limit controls.""" + + search: NotRequired[ADKSearchConfig] + """Shared search strategy and language controls.""" + + artifact: NotRequired[ADKArtifactConfig] + """Shared artifact metadata and storage URI controls.""" + + optimizations: NotRequired[ADKOptimizationConfig] + """Shared optimization negotiation controls.""" + + oracle: NotRequired[ADKOracleConfig] + """Oracle-specific ADK capability settings.""" + + spanner: NotRequired[ADKSpannerConfig] + """Spanner-specific ADK capability settings.""" + + adbc: NotRequired[ADKADBCConfig] + """ADBC-specific ADK capability settings.""" + + bigquery: NotRequired[ADKBigQueryConfig] + """BigQuery-specific ADK capability settings.""" + enable_sessions: NotRequired[bool] """Enable session store at runtime. Default: True. diff --git a/sqlspec/extensions/adk/_config_utils.py b/sqlspec/extensions/adk/_config_utils.py index eb63fafa6..f6ae820cf 100644 --- a/sqlspec/extensions/adk/_config_utils.py +++ b/sqlspec/extensions/adk/_config_utils.py @@ -28,6 +28,9 @@ class _ADKSessionStoreConfig(TypedDict): session_table: str events_table: str + app_state_table: str + user_state_table: str + metadata_table: str owner_id_column: NotRequired[str] @@ -45,6 +48,7 @@ class _ADKArtifactStoreConfig(TypedDict): """Normalized ADK artifact store configuration.""" artifact_table: str + storage_uri: NotRequired[str] class _ADKConfigSource(Protocol): @@ -62,17 +66,40 @@ def _get_adk_config_from_extension(config: _ADKConfigSource) -> dict[str, Any]: return dict(cast("dict[str, Any]", config.extension_config.get("adk", {}))) +def _get_adk_config_section(adk_config: dict[str, Any], name: str) -> dict[str, Any]: + """Return a mutable nested ADK config section.""" + + value = adk_config.get(name) + return dict(cast("dict[str, Any]", value)) if isinstance(value, dict) else {} + + +def _get_first_value(*values: Any, default: Any = None) -> Any: + """Return the first non-None value.""" + + for value in values: + if value is not None: + return value + return default + + def _get_adk_session_store_config(config: _ADKConfigSource) -> _ADKSessionStoreConfig: """Return normalized session store table settings.""" adk_config = _get_adk_config_from_extension(config) - session_table = adk_config.get("session_table") - events_table = adk_config.get("events_table") + schema_config = _get_adk_config_section(adk_config, "schema") + session_table = _get_first_value(schema_config.get("session_table"), adk_config.get("session_table")) + events_table = _get_first_value(schema_config.get("events_table"), adk_config.get("events_table")) + app_state_table = _get_first_value(schema_config.get("app_state_table"), adk_config.get("app_state_table")) + user_state_table = _get_first_value(schema_config.get("user_state_table"), adk_config.get("user_state_table")) + metadata_table = _get_first_value(schema_config.get("metadata_table"), adk_config.get("metadata_table")) result: _ADKSessionStoreConfig = { "session_table": str(session_table) if session_table is not None else "adk_sessions", "events_table": str(events_table) if events_table is not None else "adk_events", + "app_state_table": str(app_state_table) if app_state_table is not None else "adk_app_states", + "user_state_table": str(user_state_table) if user_state_table is not None else "adk_user_states", + "metadata_table": str(metadata_table) if metadata_table is not None else "adk_internal_metadata", } - owner_id = adk_config.get("owner_id_column") + owner_id = _get_first_value(schema_config.get("owner_id_column"), adk_config.get("owner_id_column")) if owner_id is not None: result["owner_id_column"] = cast("str", owner_id) return result @@ -82,18 +109,34 @@ def _get_adk_memory_store_config(config: _ADKConfigSource) -> _ADKMemoryStoreCon """Return normalized memory store settings.""" adk_config = _get_adk_config_from_extension(config) - enable_memory = adk_config.get("enable_memory") - memory_table = adk_config.get("memory_table") - use_fts = adk_config.get("memory_use_fts") - max_results = adk_config.get("memory_max_results") + schema_config = _get_adk_config_section(adk_config, "schema") + memory_config = _get_adk_config_section(adk_config, "memory") + search_config = _get_adk_config_section(adk_config, "search") + nested_memory_search_config = _get_adk_config_section(memory_config, "search") + enable_memory = _get_first_value(memory_config.get("enabled"), adk_config.get("enable_memory")) + memory_table = _get_first_value( + memory_config.get("table"), schema_config.get("memory_table"), adk_config.get("memory_table") + ) + use_fts = _get_first_value( + nested_memory_search_config.get("use_fts"), + search_config.get("use_fts"), + memory_config.get("use_fts"), + adk_config.get("memory_use_fts"), + ) + max_results = _get_first_value( + memory_config.get("max_results"), + nested_memory_search_config.get("max_results"), + search_config.get("max_results"), + adk_config.get("memory_max_results"), + ) result: _ADKMemoryStoreConfig = { "enable_memory": bool(enable_memory) if enable_memory is not None else True, "memory_table": str(memory_table) if memory_table is not None else "adk_memory_entries", "use_fts": bool(use_fts) if use_fts is not None else False, - "max_results": int(max_results) if isinstance(max_results, int) else 20, + "max_results": int(max_results) if type(max_results) is int else 20, } - owner_id = adk_config.get("owner_id_column") + owner_id = _get_first_value(schema_config.get("owner_id_column"), adk_config.get("owner_id_column")) if owner_id is not None: result["owner_id_column"] = cast("str", owner_id) return result @@ -103,8 +146,18 @@ def _get_adk_artifact_store_config(config: _ADKConfigSource) -> _ADKArtifactStor """Return normalized artifact store settings.""" adk_config = _get_adk_config_from_extension(config) - artifact_table = adk_config.get("artifact_table") - return {"artifact_table": str(artifact_table) if artifact_table is not None else "adk_artifact_versions"} + schema_config = _get_adk_config_section(adk_config, "schema") + artifact_config = _get_adk_config_section(adk_config, "artifact") + artifact_table = _get_first_value( + artifact_config.get("table"), schema_config.get("artifact_table"), adk_config.get("artifact_table") + ) + result: _ADKArtifactStoreConfig = { + "artifact_table": str(artifact_table) if artifact_table is not None else "adk_artifact_versions" + } + storage_uri = _get_first_value(artifact_config.get("storage_uri"), adk_config.get("artifact_storage_uri")) + if storage_uri is not None: + result["storage_uri"] = str(storage_uri) + return result def _resolve_adk_store_path(config: Any, store_suffix: str) -> str: @@ -175,10 +228,16 @@ def _is_adk_memory_migration_enabled(config: Any) -> bool: """Return whether ADK memory DDL should be included for this config.""" adk_config = _get_adk_config_from_extension(cast("_ADKConfigSource", config)) - include_memory = adk_config.get("include_memory_migration") + schema_config = _get_adk_config_section(adk_config, "schema") + memory_config = _get_adk_config_section(adk_config, "memory") + include_memory = _get_first_value( + schema_config.get("include_memory_migration"), + memory_config.get("include_migration"), + adk_config.get("include_memory_migration"), + ) if include_memory is not None: return bool(include_memory) - return bool(adk_config.get("enable_memory", True)) + return bool(_get_first_value(memory_config.get("enabled"), adk_config.get("enable_memory"), default=True)) def _validate_adk_store_registration(config: Any) -> None: diff --git a/tests/unit/extensions/test_adk/test_config_resolution.py b/tests/unit/extensions/test_adk/test_config_resolution.py new file mode 100644 index 000000000..7e84d2e15 --- /dev/null +++ b/tests/unit/extensions/test_adk/test_config_resolution.py @@ -0,0 +1,79 @@ +"""Tests for ADK clean-break configuration resolution.""" + +from typing import Any + +from sqlspec.config import ADKConfig +from sqlspec.extensions.adk._config_utils import ( + _get_adk_artifact_store_config, + _get_adk_memory_store_config, + _get_adk_session_store_config, + _is_adk_memory_migration_enabled, +) + + +class _Config: + extension_config: dict[str, dict[str, Any]] + + def __init__(self, adk_config: dict[str, Any]) -> None: + self.extension_config = {"adk": adk_config} + + +def test_adk_config_declares_nested_capability_sections() -> None: + expected = {"schema", "memory", "search", "artifact", "optimizations", "oracle", "spanner", "adbc", "bigquery"} + + assert expected <= set(ADKConfig.__annotations__) + + +def test_nested_schema_config_resolves_all_adk_table_names() -> None: + config = _Config({ + "session_table": "flat_sessions", + "schema": { + "session_table": "agent_sessions", + "events_table": "agent_events", + "app_state_table": "agent_app_states", + "user_state_table": "agent_user_states", + "metadata_table": "agent_metadata", + "owner_id_column": "tenant_id UUID", + }, + }) + + resolved = _get_adk_session_store_config(config) + + assert resolved == { + "session_table": "agent_sessions", + "events_table": "agent_events", + "app_state_table": "agent_app_states", + "user_state_table": "agent_user_states", + "metadata_table": "agent_metadata", + "owner_id_column": "tenant_id UUID", + } + + +def test_nested_memory_and_search_config_resolve_memory_store_settings() -> None: + config = _Config({ + "enable_memory": True, + "memory": {"enabled": False, "table": "agent_memories", "max_results": 50}, + "search": {"use_fts": True, "language": "simple"}, + }) + + resolved = _get_adk_memory_store_config(config) + + assert resolved == {"enable_memory": False, "memory_table": "agent_memories", "use_fts": True, "max_results": 50} + + +def test_nested_artifact_config_resolves_table_and_storage_uri() -> None: + config = _Config({ + "artifact_table": "flat_artifacts", + "artifact_storage_uri": "file:///flat", + "artifact": {"table": "agent_artifacts", "storage_uri": "s3://bucket/adk"}, + }) + + resolved = _get_adk_artifact_store_config(config) + + assert resolved == {"artifact_table": "agent_artifacts", "storage_uri": "s3://bucket/adk"} + + +def test_schema_include_memory_migration_overrides_runtime_memory_enablement() -> None: + config = _Config({"memory": {"enabled": True}, "schema": {"include_memory_migration": False}}) + + assert not _is_adk_memory_migration_enabled(config) From ee54548509417e094b33ce5671030bae59094284 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 23 May 2026 21:10:28 +0000 Subject: [PATCH 08/29] feat(adk): add version planning model --- sqlspec/config.py | 13 ++ sqlspec/extensions/adk/_config_utils.py | 8 + sqlspec/extensions/adk/_versioning.py | 142 ++++++++++++++++++ .../extensions/test_adk/test_versioning.py | 82 ++++++++++ 4 files changed, 245 insertions(+) create mode 100644 sqlspec/extensions/adk/_versioning.py create mode 100644 tests/unit/extensions/test_adk/test_versioning.py diff --git a/sqlspec/config.py b/sqlspec/config.py index 775522c8b..7fac0ad1b 100644 --- a/sqlspec/config.py +++ b/sqlspec/config.py @@ -695,6 +695,7 @@ class ADKSchemaConfig(TypedDict): metadata_table: NotRequired[str] owner_id_column: NotRequired[str] schema_version: NotRequired[int] + payload_versions: NotRequired["ADKPayloadVersionsConfig"] include_sessions_migration: NotRequired[bool] include_memory_migration: NotRequired[bool] include_artifact_migration: NotRequired[bool] @@ -709,6 +710,15 @@ class ADKSearchConfig(TypedDict): max_results: NotRequired[int] +class ADKPayloadVersionsConfig(TypedDict): + """Shared ADK payload version pinning.""" + + event: NotRequired[int] + state: NotRequired[int] + memory: NotRequired[int] + artifact: NotRequired[int] + + class ADKMemoryConfig(TypedDict): """Shared ADK memory configuration.""" @@ -815,6 +825,9 @@ class ADKConfig(TypedDict): search: NotRequired[ADKSearchConfig] """Shared search strategy and language controls.""" + payloads: NotRequired[ADKPayloadVersionsConfig] + """Shared payload version pins.""" + artifact: NotRequired[ADKArtifactConfig] """Shared artifact metadata and storage URI controls.""" diff --git a/sqlspec/extensions/adk/_config_utils.py b/sqlspec/extensions/adk/_config_utils.py index f6ae820cf..e34780f55 100644 --- a/sqlspec/extensions/adk/_config_utils.py +++ b/sqlspec/extensions/adk/_config_utils.py @@ -6,6 +6,7 @@ from typing_extensions import NotRequired, TypedDict from sqlspec.exceptions import SQLSpecError +from sqlspec.extensions.adk._versioning import ADKVersionPlan, resolve_adk_version_plan from sqlspec.utils.module_loader import import_string __all__ = ( @@ -18,6 +19,7 @@ "_get_adk_memory_migration_store_class", "_get_adk_memory_store_config", "_get_adk_session_store_config", + "_get_adk_version_plan", "_is_adk_memory_migration_enabled", "_validate_adk_store_registration", ) @@ -160,6 +162,12 @@ def _get_adk_artifact_store_config(config: _ADKConfigSource) -> _ADKArtifactStor return result +def _get_adk_version_plan(config: _ADKConfigSource) -> ADKVersionPlan: + """Return normalized ADK schema and payload version settings.""" + + return resolve_adk_version_plan(_get_adk_config_from_extension(config)) + + def _resolve_adk_store_path(config: Any, store_suffix: str) -> str: """Return the adapter-specific ADK store import path.""" diff --git a/sqlspec/extensions/adk/_versioning.py b/sqlspec/extensions/adk/_versioning.py new file mode 100644 index 000000000..2ceb5ac06 --- /dev/null +++ b/sqlspec/extensions/adk/_versioning.py @@ -0,0 +1,142 @@ +"""ADK schema and payload version planning.""" + +from collections.abc import Mapping +from dataclasses import dataclass +from typing import Final, Literal, TypeAlias + +from sqlspec.exceptions import ImproperConfigurationError + +__all__ = ("ADKVersionPlan", "resolve_adk_version_plan", "validate_adk_version_plan") + + +ADKPayloadKind: TypeAlias = Literal["event", "state", "memory", "artifact"] + +ADK_SCHEMA_VERSION: Final = 1 +ADK_EVENT_PAYLOAD_VERSION: Final = 1 +ADK_STATE_PAYLOAD_VERSION: Final = 1 +ADK_MEMORY_PAYLOAD_VERSION: Final = 1 +ADK_ARTIFACT_PAYLOAD_VERSION: Final = 1 + +ADK_SCHEMA_VERSION_KEY: Final = "sqlspec.adk.schema_version" +ADK_PAYLOAD_VERSION_KEYS: Final[dict[ADKPayloadKind, str]] = { + "event": "sqlspec.adk.payload.event", + "state": "sqlspec.adk.payload.state", + "memory": "sqlspec.adk.payload.memory", + "artifact": "sqlspec.adk.payload.artifact", +} + +SUPPORTED_ADK_SCHEMA_VERSIONS: Final[frozenset[int]] = frozenset({ADK_SCHEMA_VERSION}) +SUPPORTED_ADK_PAYLOAD_VERSIONS: Final[dict[ADKPayloadKind, frozenset[int]]] = { + "event": frozenset({ADK_EVENT_PAYLOAD_VERSION}), + "state": frozenset({ADK_STATE_PAYLOAD_VERSION}), + "memory": frozenset({ADK_MEMORY_PAYLOAD_VERSION}), + "artifact": frozenset({ADK_ARTIFACT_PAYLOAD_VERSION}), +} + + +@dataclass(frozen=True, slots=True) +class ADKVersionPlan: + """Resolved ADK schema and payload version contract.""" + + schema_version: int = ADK_SCHEMA_VERSION + event_payload_version: int = ADK_EVENT_PAYLOAD_VERSION + state_payload_version: int = ADK_STATE_PAYLOAD_VERSION + memory_payload_version: int = ADK_MEMORY_PAYLOAD_VERSION + artifact_payload_version: int = ADK_ARTIFACT_PAYLOAD_VERSION + + def payload_versions(self) -> dict[ADKPayloadKind, int]: + """Return payload versions keyed by payload kind.""" + + return { + "event": self.event_payload_version, + "state": self.state_payload_version, + "memory": self.memory_payload_version, + "artifact": self.artifact_payload_version, + } + + def metadata_items(self) -> tuple[tuple[str, str], ...]: + """Return deterministic metadata rows for the ADK metadata table.""" + + return ( + (ADK_SCHEMA_VERSION_KEY, str(self.schema_version)), + (ADK_PAYLOAD_VERSION_KEYS["event"], str(self.event_payload_version)), + (ADK_PAYLOAD_VERSION_KEYS["state"], str(self.state_payload_version)), + (ADK_PAYLOAD_VERSION_KEYS["memory"], str(self.memory_payload_version)), + (ADK_PAYLOAD_VERSION_KEYS["artifact"], str(self.artifact_payload_version)), + ) + + +def resolve_adk_version_plan(adk_config: Mapping[str, object] | None = None) -> ADKVersionPlan: + """Resolve the configured ADK schema and payload versions.""" + + config = adk_config or {} + schema_config = _mapping(config.get("schema")) + schema_payloads = _mapping(schema_config.get("payload_versions")) + top_level_payloads = _mapping(config.get("payloads")) + plan = ADKVersionPlan( + schema_version=_version_value( + _first_value(schema_config.get("schema_version"), config.get("schema_version")), + default=ADK_SCHEMA_VERSION, + label="schema.schema_version", + ), + event_payload_version=_payload_version("event", schema_payloads, top_level_payloads), + state_payload_version=_payload_version("state", schema_payloads, top_level_payloads), + memory_payload_version=_payload_version("memory", schema_payloads, top_level_payloads), + artifact_payload_version=_payload_version("artifact", schema_payloads, top_level_payloads), + ) + validate_adk_version_plan(plan) + return plan + + +def validate_adk_version_plan(plan: ADKVersionPlan) -> None: + """Validate that a resolved ADK version plan is supported.""" + + if plan.schema_version not in SUPPORTED_ADK_SCHEMA_VERSIONS: + _raise_unsupported_version( + "schema", plan.schema_version, sorted(SUPPORTED_ADK_SCHEMA_VERSIONS), "schema.schema_version" + ) + for payload_kind, payload_version in plan.payload_versions().items(): + supported = SUPPORTED_ADK_PAYLOAD_VERSIONS[payload_kind] + if payload_version not in supported: + _raise_unsupported_version(payload_kind, payload_version, sorted(supported), "schema.payload_versions") + + +def _payload_version( + payload_kind: ADKPayloadKind, schema_payloads: Mapping[str, object], top_level_payloads: Mapping[str, object] +) -> int: + default_versions = { + "event": ADK_EVENT_PAYLOAD_VERSION, + "state": ADK_STATE_PAYLOAD_VERSION, + "memory": ADK_MEMORY_PAYLOAD_VERSION, + "artifact": ADK_ARTIFACT_PAYLOAD_VERSION, + } + return _version_value( + _first_value(schema_payloads.get(payload_kind), top_level_payloads.get(payload_kind)), + default=default_versions[payload_kind], + label=f"schema.payload_versions.{payload_kind}", + ) + + +def _mapping(value: object) -> Mapping[str, object]: + return value if isinstance(value, Mapping) else {} + + +def _first_value(*values: object) -> object | None: + for value in values: + if value is not None: + return value + return None + + +def _version_value(value: object | None, *, default: int, label: str) -> int: + if value is None: + return default + if type(value) is not int: + msg = f"ADK {label} must be an integer version, got {value!r}" + raise ImproperConfigurationError(msg) + return value + + +def _raise_unsupported_version(kind: str, version: int, supported: list[int], label: str) -> None: + msg = f"Unsupported ADK {kind} version {version!r} from {label}; supported versions: {supported}" + raise ImproperConfigurationError(msg) diff --git a/tests/unit/extensions/test_adk/test_versioning.py b/tests/unit/extensions/test_adk/test_versioning.py new file mode 100644 index 000000000..5c0019ad0 --- /dev/null +++ b/tests/unit/extensions/test_adk/test_versioning.py @@ -0,0 +1,82 @@ +"""Tests for ADK schema and payload version planning.""" + +from typing import Any + +import pytest + +from sqlspec.exceptions import ImproperConfigurationError +from sqlspec.extensions.adk._config_utils import _get_adk_version_plan +from sqlspec.extensions.adk._versioning import ( + ADK_ARTIFACT_PAYLOAD_VERSION, + ADK_EVENT_PAYLOAD_VERSION, + ADK_MEMORY_PAYLOAD_VERSION, + ADK_PAYLOAD_VERSION_KEYS, + ADK_SCHEMA_VERSION, + ADK_SCHEMA_VERSION_KEY, + ADK_STATE_PAYLOAD_VERSION, + ADKVersionPlan, + validate_adk_version_plan, +) + + +class _Config: + extension_config: dict[str, dict[str, Any]] + + def __init__(self, adk_config: dict[str, Any]) -> None: + self.extension_config = {"adk": adk_config} + + +def test_default_version_plan_matches_clean_break_v1_contract() -> None: + plan = _get_adk_version_plan(_Config({})) + + assert plan == ADKVersionPlan( + schema_version=ADK_SCHEMA_VERSION, + event_payload_version=ADK_EVENT_PAYLOAD_VERSION, + state_payload_version=ADK_STATE_PAYLOAD_VERSION, + memory_payload_version=ADK_MEMORY_PAYLOAD_VERSION, + artifact_payload_version=ADK_ARTIFACT_PAYLOAD_VERSION, + ) + + +def test_version_plan_metadata_items_include_schema_and_payload_versions() -> None: + metadata_items = dict(_get_adk_version_plan(_Config({})).metadata_items()) + + assert metadata_items == { + ADK_SCHEMA_VERSION_KEY: "1", + ADK_PAYLOAD_VERSION_KEYS["event"]: "1", + ADK_PAYLOAD_VERSION_KEYS["state"]: "1", + ADK_PAYLOAD_VERSION_KEYS["memory"]: "1", + ADK_PAYLOAD_VERSION_KEYS["artifact"]: "1", + } + + +def test_nested_schema_payload_versions_override_defaults() -> None: + plan = _get_adk_version_plan( + _Config({ + "schema": {"schema_version": 1, "payload_versions": {"event": 1, "state": 1, "memory": 1, "artifact": 1}} + }) + ) + + assert plan.event_payload_version == 1 + assert plan.state_payload_version == 1 + assert plan.memory_payload_version == 1 + assert plan.artifact_payload_version == 1 + + +@pytest.mark.parametrize( + "adk_config", + [ + {"schema": {"schema_version": 2}}, + {"schema": {"payload_versions": {"event": 2}}}, + {"schema": {"payload_versions": {"state": 2}}}, + {"schema": {"payload_versions": {"memory": 2}}}, + {"schema": {"payload_versions": {"artifact": 2}}}, + ], +) +def test_unsupported_schema_or_payload_versions_raise_configuration_error(adk_config: dict[str, Any]) -> None: + with pytest.raises(ImproperConfigurationError): + _get_adk_version_plan(_Config(adk_config)) + + +def test_validate_adk_version_plan_accepts_supported_clean_break_plan() -> None: + validate_adk_version_plan(ADKVersionPlan()) From a9ed4ddd98ada7be8ca3b3ae725cc57239005536 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 23 May 2026 21:15:16 +0000 Subject: [PATCH 09/29] feat(adk): define lifecycle controls --- sqlspec/config.py | 36 ++++++ sqlspec/extensions/adk/_config_utils.py | 8 ++ sqlspec/extensions/adk/_lifecycle.py | 105 ++++++++++++++++++ .../test_adk/test_lifecycle_config.py | 86 ++++++++++++++ 4 files changed, 235 insertions(+) create mode 100644 sqlspec/extensions/adk/_lifecycle.py create mode 100644 tests/unit/extensions/test_adk/test_lifecycle_config.py diff --git a/sqlspec/config.py b/sqlspec/config.py index 7fac0ad1b..68cd6bfca 100644 --- a/sqlspec/config.py +++ b/sqlspec/config.py @@ -719,6 +719,39 @@ class ADKPayloadVersionsConfig(TypedDict): artifact: NotRequired[int] +class ADKIndexingConfig(TypedDict): + """Shared ADK index lifecycle controls.""" + + generated_columns: NotRequired[ADKOptimizationMode] + covering_indexes: NotRequired[ADKOptimizationMode] + search_indexes: NotRequired[ADKOptimizationMode] + json_indexes: NotRequired[ADKOptimizationMode] + vector_indexes: NotRequired[ADKOptimizationMode] + + +class ADKTableOptionsConfig(TypedDict): + """Shared ADK table and index option attachment points.""" + + sessions: NotRequired[str] + events: NotRequired[str] + memory: NotRequired[str] + artifacts: NotRequired[str] + app_states: NotRequired[str] + user_states: NotRequired[str] + metadata: NotRequired[str] + expires_index: NotRequired[str] + + +class ADKLifecycleConfig(TypedDict): + """Shared ADK lifecycle controls for backend DDL chapters.""" + + partitioning: NotRequired[ADKPartitionConfig] + retention: NotRequired[ADKRetentionConfig] + indexing: NotRequired[ADKIndexingConfig] + compression: NotRequired[ADKCompressionConfig] + table_options: NotRequired[ADKTableOptionsConfig] + + class ADKMemoryConfig(TypedDict): """Shared ADK memory configuration.""" @@ -828,6 +861,9 @@ class ADKConfig(TypedDict): payloads: NotRequired[ADKPayloadVersionsConfig] """Shared payload version pins.""" + lifecycle: NotRequired[ADKLifecycleConfig] + """Shared lifecycle controls for partitioning, retention, indexing, compression, and table options.""" + artifact: NotRequired[ADKArtifactConfig] """Shared artifact metadata and storage URI controls.""" diff --git a/sqlspec/extensions/adk/_config_utils.py b/sqlspec/extensions/adk/_config_utils.py index e34780f55..ff1cf4fa7 100644 --- a/sqlspec/extensions/adk/_config_utils.py +++ b/sqlspec/extensions/adk/_config_utils.py @@ -6,6 +6,7 @@ from typing_extensions import NotRequired, TypedDict from sqlspec.exceptions import SQLSpecError +from sqlspec.extensions.adk._lifecycle import ADKLifecyclePlan, resolve_adk_lifecycle_plan from sqlspec.extensions.adk._versioning import ADKVersionPlan, resolve_adk_version_plan from sqlspec.utils.module_loader import import_string @@ -16,6 +17,7 @@ "_get_adk_adapter_store_class", "_get_adk_artifact_store_config", "_get_adk_config_from_extension", + "_get_adk_lifecycle_plan", "_get_adk_memory_migration_store_class", "_get_adk_memory_store_config", "_get_adk_session_store_config", @@ -168,6 +170,12 @@ def _get_adk_version_plan(config: _ADKConfigSource) -> ADKVersionPlan: return resolve_adk_version_plan(_get_adk_config_from_extension(config)) +def _get_adk_lifecycle_plan(config: _ADKConfigSource) -> ADKLifecyclePlan: + """Return normalized ADK lifecycle control settings.""" + + return resolve_adk_lifecycle_plan(_get_adk_config_from_extension(config)) + + def _resolve_adk_store_path(config: Any, store_suffix: str) -> str: """Return the adapter-specific ADK store import path.""" diff --git a/sqlspec/extensions/adk/_lifecycle.py b/sqlspec/extensions/adk/_lifecycle.py new file mode 100644 index 000000000..9f497e47b --- /dev/null +++ b/sqlspec/extensions/adk/_lifecycle.py @@ -0,0 +1,105 @@ +"""ADK lifecycle control resolution.""" + +from collections.abc import Mapping +from dataclasses import dataclass, field +from typing import Final, Literal, TypeAlias, cast + +from sqlspec.exceptions import ImproperConfigurationError + +__all__ = ("ADKLifecyclePlan", "resolve_adk_lifecycle_plan", "validate_adk_lifecycle_plan") + +ADKLifecycleMode: TypeAlias = Literal["auto", "enable", "disable"] + +ADK_INDEXING_CONTROLS: Final[tuple[str, ...]] = ( + "generated_columns", + "covering_indexes", + "search_indexes", + "json_indexes", + "vector_indexes", +) +ADK_LIFECYCLE_MODES: Final[frozenset[str]] = frozenset({"auto", "enable", "disable"}) + + +@dataclass(frozen=True, slots=True) +class ADKLifecyclePlan: + """Resolved ADK lifecycle controls used by backend DDL chapters.""" + + partitioning: dict[str, object] | None = None + retention: dict[str, object] | None = None + indexing: dict[str, ADKLifecycleMode] = field(default_factory=dict) + compression: dict[str, object] | None = None + table_options: dict[str, str] = field(default_factory=dict) + + +def resolve_adk_lifecycle_plan(adk_config: Mapping[str, object] | None = None) -> ADKLifecyclePlan: + """Resolve lifecycle controls from ADK extension config.""" + + config = adk_config or {} + lifecycle_config = _mapping(config.get("lifecycle")) + plan = ADKLifecyclePlan( + partitioning=_optional_mapping(_first_value(lifecycle_config.get("partitioning"), config.get("partitioning"))), + retention=_optional_mapping(_first_value(lifecycle_config.get("retention"), config.get("retention"))), + indexing=_resolve_indexing_config(config, lifecycle_config), + compression=_optional_mapping(_first_value(lifecycle_config.get("compression"), config.get("compression"))), + table_options=_resolve_table_options(config, lifecycle_config), + ) + validate_adk_lifecycle_plan(plan) + return plan + + +def validate_adk_lifecycle_plan(plan: ADKLifecyclePlan) -> None: + """Validate lifecycle controls that have shared semantics.""" + + for key, value in plan.indexing.items(): + if value not in ADK_LIFECYCLE_MODES: + msg = f"Unsupported ADK lifecycle indexing mode {value!r} for {key}; expected auto, enable, or disable" + raise ImproperConfigurationError(msg) + + +def _resolve_indexing_config( + config: Mapping[str, object], lifecycle_config: Mapping[str, object] +) -> dict[str, ADKLifecycleMode]: + lifecycle_indexing = _mapping(lifecycle_config.get("indexing")) + top_level_indexing = _mapping(config.get("indexing")) + optimizations = _mapping(config.get("optimizations")) + resolved: dict[str, ADKLifecycleMode] = {} + for key in ADK_INDEXING_CONTROLS: + value = _first_value(lifecycle_indexing.get(key), top_level_indexing.get(key), optimizations.get(key), "auto") + resolved[key] = _indexing_mode(key, value) + return resolved + + +def _resolve_table_options(config: Mapping[str, object], lifecycle_config: Mapping[str, object]) -> dict[str, str]: + flat_options = { + "sessions": config.get("session_table_options"), + "events": config.get("events_table_options"), + "memory": config.get("memory_table_options"), + "expires_index": config.get("expires_index_options"), + } + resolved = {key: str(value) for key, value in flat_options.items() if value is not None} + resolved.update({key: str(value) for key, value in _mapping(lifecycle_config.get("table_options")).items()}) + return resolved + + +def _mapping(value: object) -> Mapping[str, object]: + return value if isinstance(value, Mapping) else {} + + +def _optional_mapping(value: object | None) -> dict[str, object] | None: + if isinstance(value, Mapping): + return dict(value) + return None + + +def _first_value(*values: object) -> object | None: + for value in values: + if value is not None: + return value + return None + + +def _indexing_mode(key: str, value: object) -> ADKLifecycleMode: + if value in ADK_LIFECYCLE_MODES: + return cast("ADKLifecycleMode", value) + msg = f"Unsupported ADK lifecycle indexing mode {value!r} for {key}; expected auto, enable, or disable" + raise ImproperConfigurationError(msg) diff --git a/tests/unit/extensions/test_adk/test_lifecycle_config.py b/tests/unit/extensions/test_adk/test_lifecycle_config.py new file mode 100644 index 000000000..98585b3fc --- /dev/null +++ b/tests/unit/extensions/test_adk/test_lifecycle_config.py @@ -0,0 +1,86 @@ +"""Tests for ADK lifecycle control resolution.""" + +from typing import Any + +import pytest + +from sqlspec.config import ADKConfig +from sqlspec.exceptions import ImproperConfigurationError +from sqlspec.extensions.adk._config_utils import _get_adk_lifecycle_plan + + +class _Config: + extension_config: dict[str, dict[str, Any]] + + def __init__(self, adk_config: dict[str, Any]) -> None: + self.extension_config = {"adk": adk_config} + + +def test_adk_config_declares_lifecycle_section() -> None: + assert "lifecycle" in ADKConfig.__annotations__ + + +def test_default_lifecycle_plan_sets_indexing_controls_to_auto() -> None: + plan = _get_adk_lifecycle_plan(_Config({})) + + assert plan.partitioning is None + assert plan.retention is None + assert plan.compression is None + assert plan.indexing == { + "generated_columns": "auto", + "covering_indexes": "auto", + "search_indexes": "auto", + "json_indexes": "auto", + "vector_indexes": "auto", + } + assert plan.table_options == {} + + +def test_nested_lifecycle_sections_override_flat_legacy_keys() -> None: + plan = _get_adk_lifecycle_plan( + _Config({ + "partitioning": {"strategy": "hash", "partition_count": 4}, + "retention": {"event_ttl_seconds": 60}, + "compression": {"enabled": False}, + "session_table_options": "flat-session-options", + "lifecycle": { + "partitioning": {"strategy": "range", "interval": "month"}, + "retention": {"event_ttl_seconds": 120}, + "indexing": {"generated_columns": "enable", "covering_indexes": "disable"}, + "compression": {"enabled": True, "algorithm": "zstd"}, + "table_options": {"sessions": "nested-session-options", "events": "nested-event-options"}, + }, + }) + ) + + assert plan.partitioning == {"strategy": "range", "interval": "month"} + assert plan.retention == {"event_ttl_seconds": 120} + assert plan.compression == {"enabled": True, "algorithm": "zstd"} + assert plan.indexing["generated_columns"] == "enable" + assert plan.indexing["covering_indexes"] == "disable" + assert plan.table_options == {"sessions": "nested-session-options", "events": "nested-event-options"} + + +def test_flat_table_options_are_normalized_when_lifecycle_options_are_absent() -> None: + plan = _get_adk_lifecycle_plan( + _Config({ + "session_table_options": "session-options", + "events_table_options": "event-options", + "memory_table_options": "memory-options", + "expires_index_options": "expires-options", + }) + ) + + assert plan.table_options == { + "sessions": "session-options", + "events": "event-options", + "memory": "memory-options", + "expires_index": "expires-options", + } + + +def test_invalid_lifecycle_indexing_mode_raises_configuration_error() -> None: + config = _Config({"lifecycle": {"indexing": {"generated_columns": "sometimes"}}}) + + with pytest.raises(ImproperConfigurationError): + _get_adk_lifecycle_plan(config) From ba5e5ffb5e3c1f5438e4362be887da6ad8c3ded2 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 23 May 2026 21:19:35 +0000 Subject: [PATCH 10/29] feat(adk): define capability overrides --- sqlspec/config.py | 9 ++ sqlspec/extensions/adk/_capabilities.py | 101 ++++++++++++++++++ sqlspec/extensions/adk/_config_utils.py | 10 ++ .../extensions/test_adk/test_capabilities.py | 72 +++++++++++++ 4 files changed, 192 insertions(+) create mode 100644 sqlspec/extensions/adk/_capabilities.py create mode 100644 tests/unit/extensions/test_adk/test_capabilities.py diff --git a/sqlspec/config.py b/sqlspec/config.py index 68cd6bfca..22bebd7f5 100644 --- a/sqlspec/config.py +++ b/sqlspec/config.py @@ -752,6 +752,12 @@ class ADKLifecycleConfig(TypedDict): table_options: NotRequired[ADKTableOptionsConfig] +class ADKCapabilityConfig(TypedDict): + """Shared ADK capability detection overrides.""" + + overrides: NotRequired[dict[str, ADKOptimizationMode]] + + class ADKMemoryConfig(TypedDict): """Shared ADK memory configuration.""" @@ -864,6 +870,9 @@ class ADKConfig(TypedDict): lifecycle: NotRequired[ADKLifecycleConfig] """Shared lifecycle controls for partitioning, retention, indexing, compression, and table options.""" + capabilities: NotRequired[ADKCapabilityConfig] + """Shared detected capability overrides.""" + artifact: NotRequired[ADKArtifactConfig] """Shared artifact metadata and storage URI controls.""" diff --git a/sqlspec/extensions/adk/_capabilities.py b/sqlspec/extensions/adk/_capabilities.py new file mode 100644 index 000000000..3cf181703 --- /dev/null +++ b/sqlspec/extensions/adk/_capabilities.py @@ -0,0 +1,101 @@ +"""ADK capability detection and override resolution.""" + +from collections.abc import Mapping +from dataclasses import dataclass, field +from typing import Final, Literal, TypeAlias, cast + +from sqlspec.exceptions import ImproperConfigurationError + +__all__ = ( + "ADKCapabilityDecision", + "ADKCapabilityPlan", + "normalize_adk_capability_overrides", + "resolve_adk_capability_plan", +) + +ADKCapabilityMode: TypeAlias = Literal["auto", "enable", "disable"] +ADKCapabilitySource: TypeAlias = Literal["default", "detected", "override"] + +ADK_CAPABILITY_MODES: Final[frozenset[str]] = frozenset({"auto", "enable", "disable"}) + + +@dataclass(frozen=True, slots=True) +class ADKCapabilityDecision: + """Resolved decision for one ADK capability.""" + + feature: str + detected: bool | None + override: ADKCapabilityMode + enabled: bool + source: ADKCapabilitySource + reason: str | None = None + + +@dataclass(frozen=True, slots=True) +class ADKCapabilityPlan: + """Resolved ADK capability decisions keyed by feature name.""" + + decisions: dict[str, ADKCapabilityDecision] = field(default_factory=dict) + + def enabled_features(self) -> frozenset[str]: + """Return enabled feature names.""" + + return frozenset(feature for feature, decision in self.decisions.items() if decision.enabled) + + +def normalize_adk_capability_overrides(overrides: Mapping[str, object] | None = None) -> dict[str, ADKCapabilityMode]: + """Normalize capability overrides from ADK config.""" + + normalized: dict[str, ADKCapabilityMode] = {} + for feature, value in (overrides or {}).items(): + if value not in ADK_CAPABILITY_MODES: + msg = f"Unsupported ADK capability override {value!r} for {feature}; expected auto, enable, or disable" + raise ImproperConfigurationError(msg) + normalized[str(feature)] = cast("ADKCapabilityMode", value) + return normalized + + +def resolve_adk_capability_plan( + detected_features: Mapping[str, bool | None], overrides: Mapping[str, object] | None = None +) -> ADKCapabilityPlan: + """Resolve detected ADK capabilities with user overrides.""" + + normalized_overrides = normalize_adk_capability_overrides(overrides) + decisions: dict[str, ADKCapabilityDecision] = {} + for feature in sorted(set(detected_features) | set(normalized_overrides)): + detected = detected_features.get(feature) + override = normalized_overrides.get(feature, "auto") + decisions[feature] = _resolve_capability_decision(feature, detected, override) + return ADKCapabilityPlan(decisions=decisions) + + +def _resolve_capability_decision( + feature: str, detected: bool | None, override: ADKCapabilityMode +) -> ADKCapabilityDecision: + if override == "disable": + return ADKCapabilityDecision( + feature=feature, detected=detected, override=override, enabled=False, source="override" + ) + if override == "enable": + if detected is False: + msg = f"ADK capability {feature!r} was forced enabled but detection reported it as unsupported" + raise ImproperConfigurationError(msg) + return ADKCapabilityDecision( + feature=feature, detected=detected, override=override, enabled=True, source="override" + ) + if detected is True: + return ADKCapabilityDecision( + feature=feature, detected=detected, override=override, enabled=True, source="detected" + ) + if detected is False: + return ADKCapabilityDecision( + feature=feature, detected=detected, override=override, enabled=False, source="detected" + ) + return ADKCapabilityDecision( + feature=feature, + detected=detected, + override=override, + enabled=False, + source="default", + reason="capability was not detected", + ) diff --git a/sqlspec/extensions/adk/_config_utils.py b/sqlspec/extensions/adk/_config_utils.py index ff1cf4fa7..d8839b25c 100644 --- a/sqlspec/extensions/adk/_config_utils.py +++ b/sqlspec/extensions/adk/_config_utils.py @@ -6,6 +6,7 @@ from typing_extensions import NotRequired, TypedDict from sqlspec.exceptions import SQLSpecError +from sqlspec.extensions.adk._capabilities import ADKCapabilityMode, normalize_adk_capability_overrides from sqlspec.extensions.adk._lifecycle import ADKLifecyclePlan, resolve_adk_lifecycle_plan from sqlspec.extensions.adk._versioning import ADKVersionPlan, resolve_adk_version_plan from sqlspec.utils.module_loader import import_string @@ -16,6 +17,7 @@ "_ADKSessionStoreConfig", "_get_adk_adapter_store_class", "_get_adk_artifact_store_config", + "_get_adk_capability_overrides", "_get_adk_config_from_extension", "_get_adk_lifecycle_plan", "_get_adk_memory_migration_store_class", @@ -176,6 +178,14 @@ def _get_adk_lifecycle_plan(config: _ADKConfigSource) -> ADKLifecyclePlan: return resolve_adk_lifecycle_plan(_get_adk_config_from_extension(config)) +def _get_adk_capability_overrides(config: _ADKConfigSource) -> dict[str, ADKCapabilityMode]: + """Return normalized ADK capability overrides.""" + + adk_config = _get_adk_config_from_extension(config) + capabilities_config = _get_adk_config_section(adk_config, "capabilities") + return normalize_adk_capability_overrides(_get_adk_config_section(capabilities_config, "overrides")) + + def _resolve_adk_store_path(config: Any, store_suffix: str) -> str: """Return the adapter-specific ADK store import path.""" diff --git a/tests/unit/extensions/test_adk/test_capabilities.py b/tests/unit/extensions/test_adk/test_capabilities.py new file mode 100644 index 000000000..e8b9fbf85 --- /dev/null +++ b/tests/unit/extensions/test_adk/test_capabilities.py @@ -0,0 +1,72 @@ +"""Tests for ADK capability detection and override resolution.""" + +from typing import Any + +import pytest + +from sqlspec.config import ADKConfig +from sqlspec.exceptions import ImproperConfigurationError +from sqlspec.extensions.adk._capabilities import resolve_adk_capability_plan +from sqlspec.extensions.adk._config_utils import _get_adk_capability_overrides + + +class _Config: + extension_config: dict[str, dict[str, Any]] + + def __init__(self, adk_config: dict[str, Any]) -> None: + self.extension_config = {"adk": adk_config} + + +def test_adk_config_declares_capabilities_section() -> None: + assert "capabilities" in ADKConfig.__annotations__ + + +def test_capability_plan_uses_detected_features_by_default() -> None: + plan = resolve_adk_capability_plan(detected_features={"supports_generated_columns": True, "supports_vector": False}) + + assert plan.decisions["supports_generated_columns"].enabled is True + assert plan.decisions["supports_generated_columns"].source == "detected" + assert plan.decisions["supports_vector"].enabled is False + assert plan.decisions["supports_vector"].source == "detected" + + +def test_disable_override_wins_over_detected_feature() -> None: + plan = resolve_adk_capability_plan( + detected_features={"supports_generated_columns": True}, overrides={"supports_generated_columns": "disable"} + ) + + decision = plan.decisions["supports_generated_columns"] + assert decision.enabled is False + assert decision.override == "disable" + assert decision.source == "override" + + +def test_enable_override_rejects_known_unsupported_feature() -> None: + with pytest.raises(ImproperConfigurationError, match="supports_vector"): + resolve_adk_capability_plan( + detected_features={"supports_vector": False}, overrides={"supports_vector": "enable"} + ) + + +def test_enable_override_can_force_unknown_detection_result() -> None: + plan = resolve_adk_capability_plan(detected_features={}, overrides={"supports_json_table": "enable"}) + + decision = plan.decisions["supports_json_table"] + assert decision.enabled is True + assert decision.detected is None + assert decision.source == "override" + + +def test_config_capability_overrides_are_normalized() -> None: + overrides = _get_adk_capability_overrides( + _Config({"capabilities": {"overrides": {"supports_generated_columns": "disable"}}}) + ) + + assert overrides == {"supports_generated_columns": "disable"} + + +def test_invalid_capability_override_raises_configuration_error() -> None: + with pytest.raises(ImproperConfigurationError): + _get_adk_capability_overrides( + _Config({"capabilities": {"overrides": {"supports_generated_columns": "sometimes"}}}) + ) From c5284fb530e7105c932e126680f40f912f6b457b Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 23 May 2026 21:34:49 +0000 Subject: [PATCH 11/29] feat(adk): add base store ttl helpers --- sqlspec/extensions/adk/store.py | 35 +++++++++++++++++- .../extensions/test_adk/test_store_config.py | 37 ++++++++++++++++++- 2 files changed, 69 insertions(+), 3 deletions(-) diff --git a/sqlspec/extensions/adk/store.py b/sqlspec/extensions/adk/store.py index e8d91f9e8..350d1df56 100644 --- a/sqlspec/extensions/adk/store.py +++ b/sqlspec/extensions/adk/store.py @@ -3,6 +3,7 @@ import logging import re from abc import ABC, abstractmethod +from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any, Final, Generic, TypeVar from sqlspec.extensions.adk._config_utils import _get_adk_session_store_config @@ -11,8 +12,6 @@ from sqlspec.utils.logging import get_logger, log_with_context if TYPE_CHECKING: - from datetime import datetime - from sqlspec.config import DatabaseConfigProtocol from sqlspec.extensions.adk._types import EventRecord, SessionRecord @@ -137,6 +136,38 @@ def owner_id_column_name(self) -> "str | None": """Return the owner ID column name only (or None if not configured).""" return self._owner_id_column_name + def _calculate_expires_at(self, expires_in: "int | timedelta | None") -> "datetime | None": + """Calculate expiration timestamp from expires_in. + + Args: + expires_in: Seconds or timedelta until expiration. + + Returns: + UTC datetime of expiration, or None if no expiration. + """ + if expires_in is None: + return None + + expires_in_seconds = int(expires_in.total_seconds()) if isinstance(expires_in, timedelta) else expires_in + + if expires_in_seconds <= 0: + return None + + return datetime.now(timezone.utc) + timedelta(seconds=expires_in_seconds) + + def _value_to_bytes(self, value: "str | bytes") -> bytes: + """Convert value to bytes if needed. + + Args: + value: String or bytes value. + + Returns: + Value as bytes. + """ + if isinstance(value, str): + return value.encode("utf-8") + return value + @abstractmethod async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None diff --git a/tests/unit/extensions/test_adk/test_store_config.py b/tests/unit/extensions/test_adk/test_store_config.py index fd35d2c3e..c44b73191 100644 --- a/tests/unit/extensions/test_adk/test_store_config.py +++ b/tests/unit/extensions/test_adk/test_store_config.py @@ -2,7 +2,7 @@ """Tests for shared ADK store configuration behavior.""" import logging -from datetime import datetime +from datetime import datetime, timedelta, timezone from typing import Any import pytest @@ -169,6 +169,41 @@ def test_memory_store_contract_exports_async_surface_only() -> None: assert not hasattr(memory_store_module, "BaseSyncADKMemoryStore") +@pytest.mark.parametrize("expires_in", [None, 0, timedelta(seconds=-5)]) +def test_async_session_store_calculate_expires_at_returns_none_for_non_positive_values( + expires_in: int | timedelta | None, +) -> None: + store = _AsyncSessionStore(_Config()) + + assert store._calculate_expires_at(expires_in) is None + + +@pytest.mark.parametrize("expires_in", [3600, timedelta(hours=1)]) +def test_async_session_store_calculate_expires_at_returns_utc_expiration(expires_in: int | timedelta) -> None: + store = _AsyncSessionStore(_Config()) + before = datetime.now(timezone.utc) + timedelta(seconds=3598) + + expires_at = store._calculate_expires_at(expires_in) + + after = datetime.now(timezone.utc) + timedelta(seconds=3602) + assert expires_at is not None + assert expires_at.tzinfo is timezone.utc + assert before <= expires_at <= after + + +def test_async_session_store_value_to_bytes_encodes_strings() -> None: + store = _AsyncSessionStore(_Config()) + + assert store._value_to_bytes("abc") == b"abc" + + +def test_async_session_store_value_to_bytes_returns_existing_bytes() -> None: + store = _AsyncSessionStore(_Config()) + value = b"abc" + + assert store._value_to_bytes(value) is value + + async def test_async_memory_store_logs_ready_with_log_with_context(monkeypatch: pytest.MonkeyPatch) -> None: calls: list[dict[str, Any]] = [] From 970da47fe470154f450feab77b3e531e71bfb4f3 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 23 May 2026 21:49:36 +0000 Subject: [PATCH 12/29] feat(adk): add store cleanup hooks --- sqlspec/adapters/adbc/adk/store.py | 57 ++++++++++++++ sqlspec/adapters/aiomysql/adk/store.py | 32 ++++++++ sqlspec/adapters/aiosqlite/adk/store.py | 20 +++++ sqlspec/adapters/asyncmy/adk/store.py | 26 +++++++ sqlspec/adapters/asyncpg/adk/store.py | 20 +++++ .../adapters/cockroach_asyncpg/adk/store.py | 20 +++++ .../adapters/cockroach_psycopg/adk/store.py | 52 +++++++++++++ sqlspec/adapters/duckdb/adk/store.py | 45 +++++++++++ sqlspec/adapters/mysqlconnector/adk/store.py | 76 +++++++++++++++++++ sqlspec/adapters/oracledb/adk/store.py | 68 +++++++++++++++++ sqlspec/adapters/psqlpy/adk/store.py | 34 +++++++++ sqlspec/adapters/psycopg/adk/store.py | 60 +++++++++++++++ sqlspec/adapters/pymysql/adk/store.py | 42 ++++++++++ sqlspec/adapters/spanner/adk/store.py | 8 ++ sqlspec/adapters/sqlite/adk/store.py | 28 +++++++ sqlspec/extensions/adk/store.py | 24 ++++++ .../adapters/_adk_contract_helpers.py | 51 ++++++++++++- .../extensions/adk/test_session_operations.py | 6 ++ .../aiosqlite/extensions/adk/test_store.py | 14 +++- .../duckdb/extensions/adk/test_store.py | 10 ++- .../sqlite/extensions/adk/test_store.py | 14 +++- .../extensions/test_adk/test_store_config.py | 11 +++ 22 files changed, 714 insertions(+), 4 deletions(-) diff --git a/sqlspec/adapters/adbc/adk/store.py b/sqlspec/adapters/adbc/adk/store.py index 4b7b76e81..cc2def4db 100644 --- a/sqlspec/adapters/adbc/adk/store.py +++ b/sqlspec/adapters/adbc/adk/store.py @@ -855,6 +855,63 @@ async def get_events( """Get events for a session.""" return await async_(self._get_events)(session_id, after_timestamp, limit) + def _delete_expired_events(self, before: "datetime") -> int: + count_sql = f"SELECT COUNT(*) FROM {self._events_table} WHERE timestamp < ?" + delete_sql = f"DELETE FROM {self._events_table} WHERE timestamp < ?" + + try: + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(count_sql, (before,)) + row = cursor.fetchone() + count = int(row[0]) if row else 0 + cursor.execute(delete_sql, (before,)) + conn.commit() + return count + finally: + cursor.close() + except Exception as e: + error_msg = str(e).lower() + if any(pattern in error_msg for pattern in ADBC_TABLE_NOT_FOUND_PATTERNS): + return 0 + raise + + async def delete_expired_events(self, before: "datetime") -> int: + """Delete events older than the given timestamp.""" + return await async_(self._delete_expired_events)(before) + + def _delete_idle_sessions(self, updated_before: "datetime") -> int: + count_sql = f"SELECT COUNT(*) FROM {self._session_table} WHERE update_time < ?" + delete_events_sql = f""" + DELETE FROM {self._events_table} + WHERE session_id IN (SELECT id FROM {self._session_table} WHERE update_time < ?) + """ + delete_sessions_sql = f"DELETE FROM {self._session_table} WHERE update_time < ?" + + try: + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(count_sql, (updated_before,)) + row = cursor.fetchone() + count = int(row[0]) if row else 0 + cursor.execute(delete_events_sql, (updated_before,)) + cursor.execute(delete_sessions_sql, (updated_before,)) + conn.commit() + return count + finally: + cursor.close() + except Exception as e: + error_msg = str(e).lower() + if any(pattern in error_msg for pattern in ADBC_TABLE_NOT_FOUND_PATTERNS): + return 0 + raise + + async def delete_idle_sessions(self, updated_before: "datetime") -> int: + """Delete sessions whose update_time predates the given threshold.""" + return await async_(self._delete_idle_sessions)(updated_before) + def _append_event(self, event_record: EventRecord) -> None: """Synchronous implementation of append_event.""" self._insert_event(event_record) diff --git a/sqlspec/adapters/aiomysql/adk/store.py b/sqlspec/adapters/aiomysql/adk/store.py index 5b0c67b53..78c31f08d 100644 --- a/sqlspec/adapters/aiomysql/adk/store.py +++ b/sqlspec/adapters/aiomysql/adk/store.py @@ -466,6 +466,38 @@ async def get_events( return [] raise + async def delete_expired_events(self, before: "datetime") -> int: + sql = f"DELETE FROM {self._events_table} WHERE timestamp < %s" + + try: + async with ( + self._config.provide_connection() as conn, + AiomysqlCursor(conn, cursor_class=AiomysqlRawCursor) as cursor, + ): + await cursor.execute(sql, (before,)) + await conn.commit() + return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + except pymysql.err.ProgrammingError as e: + if "doesn't exist" in str(e) or e.args[0] == MYSQL_TABLE_NOT_FOUND_ERROR: + return 0 + raise + + async def delete_idle_sessions(self, updated_before: "datetime") -> int: + sql = f"DELETE FROM {self._session_table} WHERE update_time < %s" + + try: + async with ( + self._config.provide_connection() as conn, + AiomysqlCursor(conn, cursor_class=AiomysqlRawCursor) as cursor, + ): + await cursor.execute(sql, (updated_before,)) + await conn.commit() + return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + except pymysql.err.ProgrammingError as e: + if "doesn't exist" in str(e) or e.args[0] == MYSQL_TABLE_NOT_FOUND_ERROR: + return 0 + raise + def _parse_owner_id_column_for_mysql(column_ddl: str) -> "tuple[str, str]": """Parse owner ID column DDL for MySQL FOREIGN KEY syntax. diff --git a/sqlspec/adapters/aiosqlite/adk/store.py b/sqlspec/adapters/aiosqlite/adk/store.py index dfc12a9af..beb05ee35 100644 --- a/sqlspec/adapters/aiosqlite/adk/store.py +++ b/sqlspec/adapters/aiosqlite/adk/store.py @@ -545,6 +545,26 @@ async def get_events( return [] raise + async def delete_expired_events(self, before: datetime) -> int: + """Delete events older than the given timestamp.""" + sql = f"DELETE FROM {self._events_table} WHERE timestamp < ?" + + async with self._config.provide_connection() as conn: + await self._apply_pragmas(conn) + cursor = await conn.execute(sql, (_datetime_to_julian(before),)) + await conn.commit() + return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + + async def delete_idle_sessions(self, updated_before: datetime) -> int: + """Delete sessions whose update_time predates the given threshold.""" + sql = f"DELETE FROM {self._session_table} WHERE update_time < ?" + + async with self._config.provide_connection() as conn: + await self._apply_pragmas(conn) + cursor = await conn.execute(sql, (_datetime_to_julian(updated_before),)) + await conn.commit() + return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + class AiosqliteADKMemoryStore(BaseAsyncADKMemoryStore["AiosqliteConfig"]): """Aiosqlite ADK memory store using asynchronous SQLite driver. diff --git a/sqlspec/adapters/asyncmy/adk/store.py b/sqlspec/adapters/asyncmy/adk/store.py index 43364de36..5716035e2 100644 --- a/sqlspec/adapters/asyncmy/adk/store.py +++ b/sqlspec/adapters/asyncmy/adk/store.py @@ -441,6 +441,32 @@ async def get_events( return [] raise + async def delete_expired_events(self, before: "datetime") -> int: + sql = f"DELETE FROM {self._events_table} WHERE timestamp < %s" + + try: + async with self._config.provide_connection() as conn, conn.cursor() as cursor: + await cursor.execute(sql, (before,)) + await conn.commit() + return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + except asyncmy.errors.ProgrammingError as e: # pyright: ignore[reportAttributeAccessIssue] + if "doesn't exist" in str(e) or e.args[0] == MYSQL_TABLE_NOT_FOUND_ERROR: + return 0 + raise + + async def delete_idle_sessions(self, updated_before: "datetime") -> int: + sql = f"DELETE FROM {self._session_table} WHERE update_time < %s" + + try: + async with self._config.provide_connection() as conn, conn.cursor() as cursor: + await cursor.execute(sql, (updated_before,)) + await conn.commit() + return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + except asyncmy.errors.ProgrammingError as e: # pyright: ignore[reportAttributeAccessIssue] + if "doesn't exist" in str(e) or e.args[0] == MYSQL_TABLE_NOT_FOUND_ERROR: + return 0 + raise + def _parse_owner_id_column_for_mysql(column_ddl: str) -> "tuple[str, str]": """Parse owner ID column DDL for MySQL FOREIGN KEY syntax. diff --git a/sqlspec/adapters/asyncpg/adk/store.py b/sqlspec/adapters/asyncpg/adk/store.py index 21bf5e9ae..7e91aef24 100644 --- a/sqlspec/adapters/asyncpg/adk/store.py +++ b/sqlspec/adapters/asyncpg/adk/store.py @@ -287,6 +287,26 @@ async def get_events( except asyncpg.exceptions.UndefinedTableError: return [] + async def delete_expired_events(self, before: "datetime") -> int: + sql = f"DELETE FROM {self._events_table} WHERE timestamp < $1" + + try: + async with self._config.provide_connection() as conn: + result = await conn.execute(sql, before) + return int(result.split()[-1]) if result else 0 + except asyncpg.exceptions.UndefinedTableError: + return 0 + + async def delete_idle_sessions(self, updated_before: "datetime") -> int: + sql = f"DELETE FROM {self._session_table} WHERE update_time < $1" + + try: + async with self._config.provide_connection() as conn: + result = await conn.execute(sql, updated_before) + return int(result.split()[-1]) if result else 0 + except asyncpg.exceptions.UndefinedTableError: + return 0 + class AsyncpgADKMemoryStore(BaseAsyncADKMemoryStore["AsyncpgConfig"]): """PostgreSQL ADK memory store using asyncpg driver. diff --git a/sqlspec/adapters/cockroach_asyncpg/adk/store.py b/sqlspec/adapters/cockroach_asyncpg/adk/store.py index 31182fc0f..9c904187e 100644 --- a/sqlspec/adapters/cockroach_asyncpg/adk/store.py +++ b/sqlspec/adapters/cockroach_asyncpg/adk/store.py @@ -289,6 +289,26 @@ async def get_events( for row in rows ] + async def delete_expired_events(self, before: "datetime") -> int: + sql = f"DELETE FROM {self._events_table} WHERE timestamp < $1" + + try: + async with self._config.provide_connection() as conn: + result = await conn.execute(sql, before) + return int(result.split()[-1]) if result else 0 + except asyncpg.exceptions.UndefinedTableError: + return 0 + + async def delete_idle_sessions(self, updated_before: "datetime") -> int: + sql = f"DELETE FROM {self._session_table} WHERE update_time < $1" + + try: + async with self._config.provide_connection() as conn: + result = await conn.execute(sql, updated_before) + return int(result.split()[-1]) if result else 0 + except asyncpg.exceptions.UndefinedTableError: + return 0 + class CockroachAsyncpgADKMemoryStore(BaseAsyncADKMemoryStore["CockroachAsyncpgConfig"]): """CockroachDB ADK memory store using asyncpg driver.""" diff --git a/sqlspec/adapters/cockroach_psycopg/adk/store.py b/sqlspec/adapters/cockroach_psycopg/adk/store.py index 4c45c1b7a..c912d816d 100644 --- a/sqlspec/adapters/cockroach_psycopg/adk/store.py +++ b/sqlspec/adapters/cockroach_psycopg/adk/store.py @@ -352,6 +352,28 @@ async def get_events( except errors.UndefinedTable: return [] + async def delete_expired_events(self, before: "datetime") -> int: + sql = f"DELETE FROM {self._events_table} WHERE timestamp < %s" + + try: + async with self._config.provide_connection() as conn, conn.cursor() as cur: + await cur.execute(sql.encode(), (before,)) + await conn.commit() + return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 + except errors.UndefinedTable: + return 0 + + async def delete_idle_sessions(self, updated_before: "datetime") -> int: + sql = f"DELETE FROM {self._session_table} WHERE update_time < %s" + + try: + async with self._config.provide_connection() as conn, conn.cursor() as cur: + await cur.execute(sql.encode(), (updated_before,)) + await conn.commit() + return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 + except errors.UndefinedTable: + return 0 + class CockroachPsycopgSyncADKStore(BaseAsyncADKStore["CockroachPsycopgSyncConfig"]): """CockroachDB ADK store using psycopg sync driver. @@ -680,6 +702,36 @@ async def get_events( """Get events for a session.""" return await async_(self._get_events)(session_id, after_timestamp, limit) + def _delete_expired_events(self, before: "datetime") -> int: + sql = f"DELETE FROM {self._events_table} WHERE timestamp < %s" + + try: + with self._config.provide_connection() as conn, conn.cursor() as cur: + cur.execute(sql.encode(), (before,)) + conn.commit() + return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 + except errors.UndefinedTable: + return 0 + + async def delete_expired_events(self, before: "datetime") -> int: + """Delete events older than the given timestamp.""" + return await async_(self._delete_expired_events)(before) + + def _delete_idle_sessions(self, updated_before: "datetime") -> int: + sql = f"DELETE FROM {self._session_table} WHERE update_time < %s" + + try: + with self._config.provide_connection() as conn, conn.cursor() as cur: + cur.execute(sql.encode(), (updated_before,)) + conn.commit() + return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 + except errors.UndefinedTable: + return 0 + + async def delete_idle_sessions(self, updated_before: "datetime") -> int: + """Delete sessions whose update_time predates the given threshold.""" + return await async_(self._delete_idle_sessions)(updated_before) + def _append_event(self, event_record: EventRecord) -> None: """Synchronous implementation of append_event.""" self._insert_event(event_record) diff --git a/sqlspec/adapters/duckdb/adk/store.py b/sqlspec/adapters/duckdb/adk/store.py index 995cd0e99..051316b02 100644 --- a/sqlspec/adapters/duckdb/adk/store.py +++ b/sqlspec/adapters/duckdb/adk/store.py @@ -564,6 +564,51 @@ async def get_events( """ return await async_(self._get_events)(session_id, after_timestamp, limit) + def _delete_expired_events(self, before: "datetime") -> int: + count_sql = f"SELECT COUNT(*) FROM {self._events_table} WHERE timestamp < ?" + delete_sql = f"DELETE FROM {self._events_table} WHERE timestamp < ?" + + try: + with self._config.provide_connection() as conn: + row = conn.execute(count_sql, (before,)).fetchone() + count = int(row[0]) if row else 0 + conn.execute(delete_sql, (before,)) + conn.commit() + return count + except Exception as e: + if DUCKDB_TABLE_NOT_FOUND_ERROR in str(e): + return 0 + raise + + async def delete_expired_events(self, before: "datetime") -> int: + """Delete events older than the given timestamp.""" + return await async_(self._delete_expired_events)(before) + + def _delete_idle_sessions(self, updated_before: "datetime") -> int: + count_sql = f"SELECT COUNT(*) FROM {self._session_table} WHERE update_time < ?" + delete_events_sql = f""" + DELETE FROM {self._events_table} + WHERE session_id IN (SELECT id FROM {self._session_table} WHERE update_time < ?) + """ + delete_sessions_sql = f"DELETE FROM {self._session_table} WHERE update_time < ?" + + try: + with self._config.provide_connection() as conn: + row = conn.execute(count_sql, (updated_before,)).fetchone() + count = int(row[0]) if row else 0 + conn.execute(delete_events_sql, (updated_before,)) + conn.execute(delete_sessions_sql, (updated_before,)) + conn.commit() + return count + except Exception as e: + if DUCKDB_TABLE_NOT_FOUND_ERROR in str(e): + return 0 + raise + + async def delete_idle_sessions(self, updated_before: "datetime") -> int: + """Delete sessions whose update_time predates the given threshold.""" + return await async_(self._delete_idle_sessions)(updated_before) + class DuckdbADKMemoryStore(BaseAsyncADKMemoryStore["DuckDBConfig"]): """DuckDB ADK memory store using synchronous DuckDB driver with async wrappers. diff --git a/sqlspec/adapters/mysqlconnector/adk/store.py b/sqlspec/adapters/mysqlconnector/adk/store.py index 74576eae9..652e358cb 100644 --- a/sqlspec/adapters/mysqlconnector/adk/store.py +++ b/sqlspec/adapters/mysqlconnector/adk/store.py @@ -408,6 +408,40 @@ async def get_events( return [] raise + async def delete_expired_events(self, before: "datetime") -> int: + sql = f"DELETE FROM {self._events_table} WHERE timestamp < %s" + + try: + async with self._config.provide_connection() as conn: + cursor = await conn.cursor() + try: + await cursor.execute(sql, (before,)) + await conn.commit() + return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + finally: + await cursor.close() + except mysql.connector.Error as exc: + if "doesn't exist" in str(exc) or getattr(exc, "errno", None) == MYSQL_TABLE_NOT_FOUND_ERROR: + return 0 + raise + + async def delete_idle_sessions(self, updated_before: "datetime") -> int: + sql = f"DELETE FROM {self._session_table} WHERE update_time < %s" + + try: + async with self._config.provide_connection() as conn: + cursor = await conn.cursor() + try: + await cursor.execute(sql, (updated_before,)) + await conn.commit() + return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + finally: + await cursor.close() + except mysql.connector.Error as exc: + if "doesn't exist" in str(exc) or getattr(exc, "errno", None) == MYSQL_TABLE_NOT_FOUND_ERROR: + return 0 + raise + class MysqlConnectorSyncADKStore(BaseAsyncADKStore["MysqlConnectorSyncConfig"]): """MySQL/MariaDB ADK store using mysql-connector sync driver. @@ -774,6 +808,48 @@ async def get_events( """Get events for a session.""" return await async_(self._get_events)(session_id, after_timestamp, limit) + def _delete_expired_events(self, before: "datetime") -> int: + sql = f"DELETE FROM {self._events_table} WHERE timestamp < %s" + + try: + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, (before,)) + conn.commit() + return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + finally: + cursor.close() + except mysql.connector.Error as exc: + if "doesn't exist" in str(exc) or getattr(exc, "errno", None) == MYSQL_TABLE_NOT_FOUND_ERROR: + return 0 + raise + + async def delete_expired_events(self, before: "datetime") -> int: + """Delete events older than the given timestamp.""" + return await async_(self._delete_expired_events)(before) + + def _delete_idle_sessions(self, updated_before: "datetime") -> int: + sql = f"DELETE FROM {self._session_table} WHERE update_time < %s" + + try: + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, (updated_before,)) + conn.commit() + return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + finally: + cursor.close() + except mysql.connector.Error as exc: + if "doesn't exist" in str(exc) or getattr(exc, "errno", None) == MYSQL_TABLE_NOT_FOUND_ERROR: + return 0 + raise + + async def delete_idle_sessions(self, updated_before: "datetime") -> int: + """Delete sessions whose update_time predates the given threshold.""" + return await async_(self._delete_idle_sessions)(updated_before) + def _append_event(self, event_record: EventRecord) -> None: """Synchronous implementation of append_event.""" self._insert_event(event_record) diff --git a/sqlspec/adapters/oracledb/adk/store.py b/sqlspec/adapters/oracledb/adk/store.py index 97b51b9a2..ae002a976 100644 --- a/sqlspec/adapters/oracledb/adk/store.py +++ b/sqlspec/adapters/oracledb/adk/store.py @@ -800,6 +800,36 @@ async def get_events( return [] raise + async def delete_expired_events(self, before: "datetime") -> int: + sql = f"DELETE FROM {self._events_table} WHERE timestamp < :before" + + try: + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute(sql, {"before": before}) + await conn.commit() + return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + except oracledb.DatabaseError as e: + error_obj = e.args[0] if e.args else None + if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: + return 0 + raise + + async def delete_idle_sessions(self, updated_before: "datetime") -> int: + sql = f"DELETE FROM {self._session_table} WHERE update_time < :updated_before" + + try: + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute(sql, {"updated_before": updated_before}) + await conn.commit() + return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + except oracledb.DatabaseError as e: + error_obj = e.args[0] if e.args else None + if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: + return 0 + raise + class OracleSyncADKStore(BaseAsyncADKStore["OracleSyncConfig"]): """Oracle synchronous ADK store using oracledb sync driver. @@ -1529,6 +1559,44 @@ async def get_events( """Get events for a session.""" return await async_(self._get_events)(session_id, after_timestamp, limit) + def _delete_expired_events(self, before: "datetime") -> int: + sql = f"DELETE FROM {self._events_table} WHERE timestamp < :before" + + try: + with self._config.provide_connection() as conn: + cursor = conn.cursor() + cursor.execute(sql, {"before": before}) + conn.commit() + return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + except oracledb.DatabaseError as e: + error_obj = e.args[0] if e.args else None + if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: + return 0 + raise + + async def delete_expired_events(self, before: "datetime") -> int: + """Delete events older than the given timestamp.""" + return await async_(self._delete_expired_events)(before) + + def _delete_idle_sessions(self, updated_before: "datetime") -> int: + sql = f"DELETE FROM {self._session_table} WHERE update_time < :updated_before" + + try: + with self._config.provide_connection() as conn: + cursor = conn.cursor() + cursor.execute(sql, {"updated_before": updated_before}) + conn.commit() + return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + except oracledb.DatabaseError as e: + error_obj = e.args[0] if e.args else None + if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: + return 0 + raise + + async def delete_idle_sessions(self, updated_before: "datetime") -> int: + """Delete sessions whose update_time predates the given threshold.""" + return await async_(self._delete_idle_sessions)(updated_before) + def _append_event(self, event_record: EventRecord) -> None: """Synchronous implementation of append_event.""" sql = f""" diff --git a/sqlspec/adapters/psqlpy/adk/store.py b/sqlspec/adapters/psqlpy/adk/store.py index a1b682cef..88d520143 100644 --- a/sqlspec/adapters/psqlpy/adk/store.py +++ b/sqlspec/adapters/psqlpy/adk/store.py @@ -310,6 +310,40 @@ async def get_events( return [] raise + async def delete_expired_events(self, before: "datetime") -> int: + count_sql = f"SELECT COUNT(*) AS count FROM {self._events_table} WHERE timestamp < $1" + delete_sql = f"DELETE FROM {self._events_table} WHERE timestamp < $1" + + try: + async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] + count_result = await conn.fetch(count_sql, [before]) + count_rows: list[dict[str, Any]] = count_result.result() if count_result else [] + count = int(count_rows[0]["count"]) if count_rows else 0 + await conn.execute(delete_sql, [before]) + return count + except psqlpy.exceptions.DatabaseError as e: + error_msg = str(e).lower() + if "does not exist" in error_msg or "relation" in error_msg: + return 0 + raise + + async def delete_idle_sessions(self, updated_before: "datetime") -> int: + count_sql = f"SELECT COUNT(*) AS count FROM {self._session_table} WHERE update_time < $1" + delete_sql = f"DELETE FROM {self._session_table} WHERE update_time < $1" + + try: + async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] + count_result = await conn.fetch(count_sql, [updated_before]) + count_rows: list[dict[str, Any]] = count_result.result() if count_result else [] + count = int(count_rows[0]["count"]) if count_rows else 0 + await conn.execute(delete_sql, [updated_before]) + return count + except psqlpy.exceptions.DatabaseError as e: + error_msg = str(e).lower() + if "does not exist" in error_msg or "relation" in error_msg: + return 0 + raise + PSQLPY_STATUS_REGEX: Final[re.Pattern[str]] = re.compile(r"^([A-Z]+)(?:\s+(\d+))?\s+(\d+)$", re.IGNORECASE) diff --git a/sqlspec/adapters/psycopg/adk/store.py b/sqlspec/adapters/psycopg/adk/store.py index cf5135308..b6c4b4da1 100644 --- a/sqlspec/adapters/psycopg/adk/store.py +++ b/sqlspec/adapters/psycopg/adk/store.py @@ -348,6 +348,32 @@ async def get_events( except errors.UndefinedTable: return [] + async def delete_expired_events(self, before: "datetime") -> int: + query = pg_sql.SQL("DELETE FROM {table} WHERE timestamp < %s").format( + table=pg_sql.Identifier(self._events_table) + ) + + try: + async with self._config.provide_connection() as conn, conn.cursor() as cur: + await cur.execute(query, (before,)) + await conn.commit() + return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 + except errors.UndefinedTable: + return 0 + + async def delete_idle_sessions(self, updated_before: "datetime") -> int: + query = pg_sql.SQL("DELETE FROM {table} WHERE update_time < %s").format( + table=pg_sql.Identifier(self._session_table) + ) + + try: + async with self._config.provide_connection() as conn, conn.cursor() as cur: + await cur.execute(query, (updated_before,)) + await conn.commit() + return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 + except errors.UndefinedTable: + return 0 + class PsycopgSyncADKStore(BaseAsyncADKStore["PsycopgSyncConfig"]): """PostgreSQL synchronous ADK store using Psycopg3 driver. @@ -684,6 +710,40 @@ async def get_events( """Get events for a session.""" return await async_(self._get_events)(session_id, after_timestamp, limit) + def _delete_expired_events(self, before: "datetime") -> int: + query = pg_sql.SQL("DELETE FROM {table} WHERE timestamp < %s").format( + table=pg_sql.Identifier(self._events_table) + ) + + try: + with self._config.provide_connection() as conn, conn.cursor() as cur: + cur.execute(query, (before,)) + conn.commit() + return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 + except errors.UndefinedTable: + return 0 + + async def delete_expired_events(self, before: "datetime") -> int: + """Delete events older than the given timestamp.""" + return await async_(self._delete_expired_events)(before) + + def _delete_idle_sessions(self, updated_before: "datetime") -> int: + query = pg_sql.SQL("DELETE FROM {table} WHERE update_time < %s").format( + table=pg_sql.Identifier(self._session_table) + ) + + try: + with self._config.provide_connection() as conn, conn.cursor() as cur: + cur.execute(query, (updated_before,)) + conn.commit() + return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 + except errors.UndefinedTable: + return 0 + + async def delete_idle_sessions(self, updated_before: "datetime") -> int: + """Delete sessions whose update_time predates the given threshold.""" + return await async_(self._delete_idle_sessions)(updated_before) + def _append_event(self, event_record: EventRecord) -> None: """Synchronous implementation of append_event.""" self._insert_event(event_record) diff --git a/sqlspec/adapters/pymysql/adk/store.py b/sqlspec/adapters/pymysql/adk/store.py index 60e7b095f..8f89e8442 100644 --- a/sqlspec/adapters/pymysql/adk/store.py +++ b/sqlspec/adapters/pymysql/adk/store.py @@ -437,6 +437,48 @@ async def get_events( """Get events for a session.""" return await async_(self._get_events)(session_id, after_timestamp, limit) + def _delete_expired_events(self, before: "datetime") -> int: + sql = f"DELETE FROM {self._events_table} WHERE timestamp < %s" + + try: + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, (before,)) + conn.commit() + return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + finally: + cursor.close() + except pymysql.MySQLError as exc: + if "doesn't exist" in str(exc) or getattr(exc, "args", [None])[0] == MYSQL_TABLE_NOT_FOUND_ERROR: + return 0 + raise + + async def delete_expired_events(self, before: "datetime") -> int: + """Delete events older than the given timestamp.""" + return await async_(self._delete_expired_events)(before) + + def _delete_idle_sessions(self, updated_before: "datetime") -> int: + sql = f"DELETE FROM {self._session_table} WHERE update_time < %s" + + try: + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, (updated_before,)) + conn.commit() + return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + finally: + cursor.close() + except pymysql.MySQLError as exc: + if "doesn't exist" in str(exc) or getattr(exc, "args", [None])[0] == MYSQL_TABLE_NOT_FOUND_ERROR: + return 0 + raise + + async def delete_idle_sessions(self, updated_before: "datetime") -> int: + """Delete sessions whose update_time predates the given threshold.""" + return await async_(self._delete_idle_sessions)(updated_before) + def _append_event(self, event_record: EventRecord) -> None: """Synchronous implementation of append_event.""" self._insert_event(event_record) diff --git a/sqlspec/adapters/spanner/adk/store.py b/sqlspec/adapters/spanner/adk/store.py index 4234ec9ba..0f7861260 100644 --- a/sqlspec/adapters/spanner/adk/store.py +++ b/sqlspec/adapters/spanner/adk/store.py @@ -326,6 +326,14 @@ async def get_events( """Get events for a session.""" return await async_(self._get_events)(session_id, after_timestamp, limit) + async def delete_expired_events(self, before: "datetime") -> int: + """Return 0 because Spanner row deletion policies own TTL cleanup.""" + return 0 + + async def delete_idle_sessions(self, updated_before: "datetime") -> int: + """Return 0 because Spanner row deletion policies own TTL cleanup.""" + return 0 + def _append_event(self, event_record: EventRecord) -> None: """Synchronous implementation of append_event.""" self._insert_event(event_record) diff --git a/sqlspec/adapters/sqlite/adk/store.py b/sqlspec/adapters/sqlite/adk/store.py index 275cd4858..31de7756b 100644 --- a/sqlspec/adapters/sqlite/adk/store.py +++ b/sqlspec/adapters/sqlite/adk/store.py @@ -597,6 +597,34 @@ async def get_events( """ return await async_(self._get_events)(session_id, after_timestamp, limit) + def _delete_expired_events(self, before: datetime) -> int: + """Synchronous implementation of delete_expired_events.""" + sql = f"DELETE FROM {self._events_table} WHERE timestamp < ?" + + with self._config.provide_connection() as conn: + self._apply_pragmas(conn) + cursor = conn.execute(sql, (_datetime_to_julian(before),)) + conn.commit() + return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + + async def delete_expired_events(self, before: datetime) -> int: + """Delete events older than the given timestamp.""" + return await async_(self._delete_expired_events)(before) + + def _delete_idle_sessions(self, updated_before: datetime) -> int: + """Synchronous implementation of delete_idle_sessions.""" + sql = f"DELETE FROM {self._session_table} WHERE update_time < ?" + + with self._config.provide_connection() as conn: + self._apply_pragmas(conn) + cursor = conn.execute(sql, (_datetime_to_julian(updated_before),)) + conn.commit() + return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + + async def delete_idle_sessions(self, updated_before: datetime) -> int: + """Delete sessions whose update_time predates the given threshold.""" + return await async_(self._delete_idle_sessions)(updated_before) + class SqliteADKMemoryStore(BaseAsyncADKMemoryStore["SqliteConfig"]): """SQLite ADK memory store using synchronous SQLite driver. diff --git a/sqlspec/extensions/adk/store.py b/sqlspec/extensions/adk/store.py index 350d1df56..34f6046f6 100644 --- a/sqlspec/extensions/adk/store.py +++ b/sqlspec/extensions/adk/store.py @@ -281,6 +281,30 @@ async def get_events( """ raise NotImplementedError + @abstractmethod + async def delete_expired_events(self, before: datetime) -> int: + """Delete events older than the given timestamp. + + Args: + before: Timestamp threshold; events with timestamp earlier than this value are deleted. + + Returns: + Number of event rows deleted. + """ + raise NotImplementedError + + @abstractmethod + async def delete_idle_sessions(self, updated_before: datetime) -> int: + """Delete sessions whose update_time predates the given threshold. + + Args: + updated_before: Timestamp threshold; sessions updated earlier than this value are deleted. + + Returns: + Number of session rows deleted. + """ + raise NotImplementedError + @abstractmethod async def create_tables(self) -> None: """Create the sessions and events tables if they don't exist.""" diff --git a/tests/integration/adapters/_adk_contract_helpers.py b/tests/integration/adapters/_adk_contract_helpers.py index 6fd1e37a9..e46c44427 100644 --- a/tests/integration/adapters/_adk_contract_helpers.py +++ b/tests/integration/adapters/_adk_contract_helpers.py @@ -6,7 +6,11 @@ from sqlspec.extensions.adk import EventRecord, MemoryRecord, SessionRecord -__all__ = ("assert_memory_store_contract", "assert_session_event_store_contract") +__all__ = ( + "assert_memory_store_contract", + "assert_session_event_cleanup_contract", + "assert_session_event_store_contract", +) class SessionEventStore(Protocol): @@ -34,6 +38,10 @@ async def get_events( self, session_id: str, after_timestamp: datetime | None = None, limit: int | None = None ) -> list[EventRecord]: ... + async def delete_expired_events(self, before: datetime) -> int: ... + + async def delete_idle_sessions(self, updated_before: datetime) -> int: ... + class MemoryStore(Protocol): """Minimal ADK memory store surface used by contract tests.""" @@ -193,6 +201,47 @@ async def assert_session_event_store_contract(store: SessionEventStore, *, marke assert await store.get_events(session_id) == [] +async def assert_session_event_cleanup_contract(store: SessionEventStore, *, marker: str) -> None: + """Assert ADK session/event cleanup hooks remove only matching rows.""" + app_name = _contract_key(marker, "cleanup-app") + user_id = _contract_key(marker, "cleanup-user") + session_id = _contract_key(marker, "cleanup-session") + old_time = datetime(2026, 5, 1, 12, 0, tzinfo=timezone.utc) + new_time = datetime(2026, 5, 10, 12, 0, tzinfo=timezone.utc) + + await store.create_session(session_id, app_name, user_id, {"cleanup": True}) + await store.append_event( + _event_record( + session_id=session_id, + event_id="cleanup-old-event", + invocation_id="cleanup-old", + author="user", + timestamp=old_time, + event_data={"content": {"parts": [{"text": "old"}]}}, + ) + ) + await store.append_event( + _event_record( + session_id=session_id, + event_id="cleanup-new-event", + invocation_id="cleanup-new", + author="user", + timestamp=new_time, + event_data={"content": {"parts": [{"text": "new"}]}}, + ) + ) + + deleted_events = await store.delete_expired_events(datetime(2026, 5, 5, tzinfo=timezone.utc)) + assert deleted_events == 1 + remaining_events = await store.get_events(session_id) + assert [event["invocation_id"] for event in remaining_events] == ["cleanup-new"] + + deleted_sessions = await store.delete_idle_sessions(datetime(2100, 1, 1, tzinfo=timezone.utc)) + assert deleted_sessions == 1 + assert await store.get_session(session_id) is None + assert await store.get_events(session_id) == [] + + async def assert_memory_store_contract(store: MemoryStore, *, marker: str) -> None: """Assert the shared ADK memory store acceptance contract.""" app_name = _contract_key(marker, "app") diff --git a/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py b/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py index b749461d7..0719cb3ef 100644 --- a/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py +++ b/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py @@ -7,6 +7,7 @@ from sqlspec.adapters.adbc import AdbcConfig from sqlspec.adapters.adbc.adk import AdbcADKStore +from tests.integration.adapters._adk_contract_helpers import assert_session_event_cleanup_contract pytestmark = [pytest.mark.xdist_group("sqlite"), pytest.mark.adbc, pytest.mark.integration] @@ -91,6 +92,11 @@ async def test_delete_session(adbc_store: Any) -> None: assert await adbc_store.get_session(session_id) is None +async def test_session_event_cleanup_contract(adbc_store: Any) -> None: + """ADBC satisfies the shared ADK cleanup hook contract on SQLite.""" + await assert_session_event_cleanup_contract(adbc_store, marker="adbc") + + async def test_list_sessions(adbc_store: Any) -> None: """Test listing sessions for an app and user.""" app_name = "test-app" diff --git a/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py b/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py index ba9bdb174..bad01ff0d 100644 --- a/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py +++ b/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py @@ -8,7 +8,10 @@ from sqlspec.adapters.aiosqlite import AiosqliteConfig from sqlspec.adapters.aiosqlite.adk import AiosqliteADKStore from sqlspec.extensions.adk import EventRecord -from tests.integration.adapters._adk_contract_helpers import assert_session_event_store_contract +from tests.integration.adapters._adk_contract_helpers import ( + assert_session_event_cleanup_contract, + assert_session_event_store_contract, +) pytestmark = pytest.mark.xdist_group("sqlite") @@ -64,6 +67,15 @@ async def test_aiosqlite_session_event_store_shared_contract(tmp_path: Path) -> await config.close_pool() +async def test_aiosqlite_session_event_cleanup_contract(tmp_path: Path) -> None: + """AioSQLite satisfies the shared ADK cleanup hook contract.""" + config, store = await _build_store(tmp_path) + try: + await assert_session_event_cleanup_contract(store, marker="aiosqlite") + finally: + await config.close_pool() + + async def test_aiosqlite_append_event_and_update_state_is_atomic_contract(tmp_path: Path) -> None: """Event append and durable state update happen through the clean-break method.""" config, store = await _build_store(tmp_path) diff --git a/tests/integration/adapters/duckdb/extensions/adk/test_store.py b/tests/integration/adapters/duckdb/extensions/adk/test_store.py index 423e23a16..e62aca06d 100644 --- a/tests/integration/adapters/duckdb/extensions/adk/test_store.py +++ b/tests/integration/adapters/duckdb/extensions/adk/test_store.py @@ -10,7 +10,10 @@ from sqlspec.adapters.duckdb.adk import DuckdbADKStore from sqlspec.adapters.duckdb.config import DuckDBConfig from sqlspec.extensions.adk import EventRecord -from tests.integration.adapters._adk_contract_helpers import assert_session_event_store_contract +from tests.integration.adapters._adk_contract_helpers import ( + assert_session_event_cleanup_contract, + assert_session_event_store_contract, +) pytestmark = [pytest.mark.duckdb, pytest.mark.integration] @@ -53,6 +56,11 @@ async def test_duckdb_session_event_store_shared_contract(duckdb_adk_store: Duck await assert_session_event_store_contract(duckdb_adk_store, marker="duckdb") +async def test_duckdb_session_event_cleanup_contract(duckdb_adk_store: DuckdbADKStore) -> None: + """DuckDB satisfies cleanup hooks while manually cascading event deletion.""" + await assert_session_event_cleanup_contract(duckdb_adk_store, marker="duckdb") + + async def test_create_and_get_session(duckdb_adk_store: DuckdbADKStore) -> None: """Test creating and retrieving a session.""" session_id = "session-001" diff --git a/tests/integration/adapters/sqlite/extensions/adk/test_store.py b/tests/integration/adapters/sqlite/extensions/adk/test_store.py index 7a22a3b3f..8f4380503 100644 --- a/tests/integration/adapters/sqlite/extensions/adk/test_store.py +++ b/tests/integration/adapters/sqlite/extensions/adk/test_store.py @@ -8,7 +8,10 @@ from sqlspec.adapters.sqlite import SqliteConfig from sqlspec.adapters.sqlite.adk import SqliteADKStore from sqlspec.extensions.adk import EventRecord -from tests.integration.adapters._adk_contract_helpers import assert_session_event_store_contract +from tests.integration.adapters._adk_contract_helpers import ( + assert_session_event_cleanup_contract, + assert_session_event_store_contract, +) pytestmark = pytest.mark.xdist_group("sqlite") @@ -44,6 +47,15 @@ async def test_sqlite_session_event_store_shared_contract(tmp_path: Path) -> Non config.close_pool() +async def test_sqlite_session_event_cleanup_contract(tmp_path: Path) -> None: + """SQLite satisfies the shared ADK cleanup hook contract.""" + config, store = await _build_store(tmp_path) + try: + await assert_session_event_cleanup_contract(store, marker="sqlite") + finally: + config.close_pool() + + async def test_sqlite_append_event_and_update_state_is_atomic_contract(tmp_path: Path) -> None: """Event append and durable state update happen through the clean-break method.""" config, store = await _build_store(tmp_path) diff --git a/tests/unit/extensions/test_adk/test_store_config.py b/tests/unit/extensions/test_adk/test_store_config.py index c44b73191..42f054860 100644 --- a/tests/unit/extensions/test_adk/test_store_config.py +++ b/tests/unit/extensions/test_adk/test_store_config.py @@ -76,6 +76,12 @@ async def get_events( ) -> list[EventRecord]: return [] + async def delete_expired_events(self, before: datetime) -> int: + return 0 + + async def delete_idle_sessions(self, updated_before: datetime) -> int: + return 0 + async def create_tables(self) -> None: return None @@ -169,6 +175,11 @@ def test_memory_store_contract_exports_async_surface_only() -> None: assert not hasattr(memory_store_module, "BaseSyncADKMemoryStore") +def test_session_store_contract_declares_cleanup_hooks() -> None: + assert "delete_expired_events" in BaseAsyncADKStore.__abstractmethods__ + assert "delete_idle_sessions" in BaseAsyncADKStore.__abstractmethods__ + + @pytest.mark.parametrize("expires_in", [None, 0, timedelta(seconds=-5)]) def test_async_session_store_calculate_expires_at_returns_none_for_non_positive_values( expires_in: int | timedelta | None, From fadb8e7152fafb82b01f9a9205d96ac339e51f9e Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 23 May 2026 22:13:32 +0000 Subject: [PATCH 13/29] feat(adk): renew sessions on read --- sqlspec/adapters/adbc/adk/store.py | 14 ++++-- sqlspec/adapters/aiomysql/adk/store.py | 12 ++++- sqlspec/adapters/aiosqlite/adk/store.py | 12 ++++- sqlspec/adapters/asyncmy/adk/store.py | 12 ++++- sqlspec/adapters/asyncpg/adk/store.py | 16 ++++-- .../adapters/cockroach_asyncpg/adk/store.py | 16 ++++-- .../adapters/cockroach_psycopg/adk/store.py | 50 +++++++++++++------ sqlspec/adapters/duckdb/adk/store.py | 16 ++++-- sqlspec/adapters/mysqlconnector/adk/store.py | 24 +++++++-- sqlspec/adapters/oracledb/adk/store.py | 30 +++++++++-- sqlspec/adapters/psqlpy/adk/store.py | 16 ++++-- sqlspec/adapters/psycopg/adk/store.py | 50 +++++++++++++------ sqlspec/adapters/pymysql/adk/store.py | 15 ++++-- sqlspec/adapters/spanner/adk/store.py | 18 +++++-- sqlspec/adapters/sqlite/adk/store.py | 16 ++++-- sqlspec/extensions/adk/service.py | 13 +++-- sqlspec/extensions/adk/store.py | 5 +- .../adapters/_adk_contract_helpers.py | 35 ++++++++++++- .../extensions/adk/test_session_operations.py | 10 +++- .../aiosqlite/extensions/adk/test_store.py | 10 ++++ .../duckdb/extensions/adk/test_store.py | 6 +++ .../sqlite/extensions/adk/test_store.py | 10 ++++ .../unit/extensions/test_adk/test_service.py | 29 +++++++++-- .../extensions/test_adk/test_store_config.py | 11 +++- 24 files changed, 362 insertions(+), 84 deletions(-) diff --git a/sqlspec/adapters/adbc/adk/store.py b/sqlspec/adapters/adbc/adk/store.py index cc2def4db..577331839 100644 --- a/sqlspec/adapters/adbc/adk/store.py +++ b/sqlspec/adapters/adbc/adk/store.py @@ -520,11 +520,12 @@ async def create_session( """Create a new session.""" return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) - def _get_session(self, session_id: str) -> "SessionRecord | None": + def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = None) -> "SessionRecord | None": """Get session by ID. Args: session_id: Session identifier. + renew_for: If positive, touch update_time while reading. Returns: Session record or None if not found. @@ -542,6 +543,11 @@ def _get_session(self, session_id: str) -> "SessionRecord | None": with self._config.provide_connection() as conn: cursor = conn.cursor() try: + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + update_sql = f"UPDATE {self._session_table} SET update_time = ? WHERE id = ?" + cursor.execute(update_sql, (datetime.now(timezone.utc), session_id)) + conn.commit() + cursor.execute(sql, (session_id,)) row = cursor.fetchone() @@ -564,9 +570,11 @@ def _get_session(self, session_id: str) -> "SessionRecord | None": return None raise - async def get_session(self, session_id: str) -> "SessionRecord | None": + async def get_session( + self, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": """Get session by ID.""" - return await async_(self._get_session)(session_id) + return await async_(self._get_session)(session_id, renew_for) def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: """Update session state. diff --git a/sqlspec/adapters/aiomysql/adk/store.py b/sqlspec/adapters/aiomysql/adk/store.py index 78c31f08d..0237013ab 100644 --- a/sqlspec/adapters/aiomysql/adk/store.py +++ b/sqlspec/adapters/aiomysql/adk/store.py @@ -11,7 +11,7 @@ from sqlspec.utils.serializers import from_json, to_json if TYPE_CHECKING: - from datetime import datetime + from datetime import datetime, timedelta from sqlspec.adapters.aiomysql.config import AiomysqlConfig from sqlspec.extensions.adk import MemoryRecord @@ -183,11 +183,14 @@ async def create_session( return await self.get_session(session_id) # type: ignore[return-value] - async def get_session(self, session_id: str) -> "SessionRecord | None": + async def get_session( + self, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": """Get session by ID. Args: session_id: Session identifier. + renew_for: If positive, touch update_time while reading. Returns: Session record or None if not found. @@ -203,6 +206,11 @@ async def get_session(self, session_id: str) -> "SessionRecord | None": self._config.provide_connection() as conn, AiomysqlCursor(conn, cursor_class=AiomysqlRawCursor) as cursor, ): + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + update_sql = f"UPDATE {self._session_table} SET update_time = UTC_TIMESTAMP(6) WHERE id = %s" + await cursor.execute(update_sql, (session_id,)) + await conn.commit() + await cursor.execute(sql, (session_id,)) row = await cursor.fetchone() diff --git a/sqlspec/adapters/aiosqlite/adk/store.py b/sqlspec/adapters/aiosqlite/adk/store.py index beb05ee35..e95bce977 100644 --- a/sqlspec/adapters/aiosqlite/adk/store.py +++ b/sqlspec/adapters/aiosqlite/adk/store.py @@ -1,7 +1,7 @@ """Aiosqlite async ADK store for Google Agent Development Kit session/event storage.""" import sqlite3 -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any, Final, cast from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord @@ -251,11 +251,14 @@ async def create_session( id=session_id, app_name=app_name, user_id=user_id, state=state, create_time=now, update_time=now ) - async def get_session(self, session_id: str) -> "SessionRecord | None": + async def get_session( + self, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": """Get session by ID. Args: session_id: Session identifier. + renew_for: If positive, touch update_time while reading. Returns: Session record or None if not found. @@ -273,6 +276,11 @@ async def get_session(self, session_id: str) -> "SessionRecord | None": try: async with self._config.provide_connection() as conn: await self._apply_pragmas(conn) + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + update_sql = f"UPDATE {self._session_table} SET update_time = ? WHERE id = ?" + await conn.execute(update_sql, (_datetime_to_julian(datetime.now(timezone.utc)), session_id)) + await conn.commit() + cursor = await conn.execute(sql, (session_id,)) row = await cursor.fetchone() diff --git a/sqlspec/adapters/asyncmy/adk/store.py b/sqlspec/adapters/asyncmy/adk/store.py index 5716035e2..b843365ad 100644 --- a/sqlspec/adapters/asyncmy/adk/store.py +++ b/sqlspec/adapters/asyncmy/adk/store.py @@ -10,7 +10,7 @@ from sqlspec.utils.serializers import from_json, to_json if TYPE_CHECKING: - from datetime import datetime + from datetime import datetime, timedelta from sqlspec.adapters.asyncmy.config import AsyncmyConfig from sqlspec.extensions.adk import MemoryRecord @@ -179,11 +179,14 @@ async def create_session( return await self.get_session(session_id) # type: ignore[return-value] - async def get_session(self, session_id: str) -> "SessionRecord | None": + async def get_session( + self, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": """Get session by ID. Args: session_id: Session identifier. + renew_for: If positive, touch update_time while reading. Returns: Session record or None if not found. @@ -196,6 +199,11 @@ async def get_session(self, session_id: str) -> "SessionRecord | None": try: async with self._config.provide_connection() as conn, conn.cursor() as cursor: + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + update_sql = f"UPDATE {self._session_table} SET update_time = UTC_TIMESTAMP(6) WHERE id = %s" + await cursor.execute(update_sql, (session_id,)) + await conn.commit() + await cursor.execute(sql, (session_id,)) row = await cursor.fetchone() diff --git a/sqlspec/adapters/asyncpg/adk/store.py b/sqlspec/adapters/asyncpg/adk/store.py index 7e91aef24..06775350f 100644 --- a/sqlspec/adapters/asyncpg/adk/store.py +++ b/sqlspec/adapters/asyncpg/adk/store.py @@ -9,7 +9,7 @@ from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore if TYPE_CHECKING: - from datetime import datetime + from datetime import datetime, timedelta from sqlspec.adapters.asyncpg.config import AsyncpgConfig from sqlspec.extensions.adk import MemoryRecord @@ -115,8 +115,18 @@ async def create_session( return await self.get_session(session_id) # type: ignore[return-value] - async def get_session(self, session_id: str) -> "SessionRecord | None": - sql = f""" + async def get_session( + self, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + sql = f""" + UPDATE {self._session_table} + SET update_time = CURRENT_TIMESTAMP + WHERE id = $1 + RETURNING id, app_name, user_id, state, create_time, update_time + """ + else: + sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} WHERE id = $1 diff --git a/sqlspec/adapters/cockroach_asyncpg/adk/store.py b/sqlspec/adapters/cockroach_asyncpg/adk/store.py index 9c904187e..942c539d6 100644 --- a/sqlspec/adapters/cockroach_asyncpg/adk/store.py +++ b/sqlspec/adapters/cockroach_asyncpg/adk/store.py @@ -9,7 +9,7 @@ from sqlspec.utils.logging import get_logger if TYPE_CHECKING: - from datetime import datetime + from datetime import datetime, timedelta from sqlspec.adapters.cockroach_asyncpg.config import CockroachAsyncpgConfig from sqlspec.extensions.adk import MemoryRecord @@ -118,8 +118,18 @@ async def create_session( raise RuntimeError(msg) return result - async def get_session(self, session_id: str) -> "SessionRecord | None": - sql = f""" + async def get_session( + self, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + sql = f""" + UPDATE {self._session_table} + SET update_time = CURRENT_TIMESTAMP + WHERE id = $1 + RETURNING id, app_name, user_id, state, create_time, update_time + """ + else: + sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} WHERE id = $1 diff --git a/sqlspec/adapters/cockroach_psycopg/adk/store.py b/sqlspec/adapters/cockroach_psycopg/adk/store.py index c912d816d..bcab87774 100644 --- a/sqlspec/adapters/cockroach_psycopg/adk/store.py +++ b/sqlspec/adapters/cockroach_psycopg/adk/store.py @@ -12,7 +12,7 @@ from sqlspec.utils.sync_tools import async_, run_ if TYPE_CHECKING: - from datetime import datetime + from datetime import datetime, timedelta from sqlspec.adapters.cockroach_psycopg.config import CockroachPsycopgAsyncConfig, CockroachPsycopgSyncConfig from sqlspec.extensions.adk import MemoryRecord @@ -162,12 +162,22 @@ async def create_session( raise RuntimeError(msg) return result - async def get_session(self, session_id: str) -> "SessionRecord | None": - sql = f""" - SELECT id, app_name, user_id, state, create_time, update_time - FROM {self._session_table} - WHERE id = %s - """ + async def get_session( + self, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + sql = f""" + UPDATE {self._session_table} + SET update_time = CURRENT_TIMESTAMP + WHERE id = %s + RETURNING id, app_name, user_id, state, create_time, update_time + """ + else: + sql = f""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {self._session_table} + WHERE id = %s + """ try: async with self._config.provide_connection() as conn, conn.cursor() as cur: @@ -485,12 +495,20 @@ async def create_session( """Create a new session.""" return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) - def _get_session(self, session_id: str) -> "SessionRecord | None": - sql = f""" - SELECT id, app_name, user_id, state, create_time, update_time - FROM {self._session_table} - WHERE id = %s - """ + def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = None) -> "SessionRecord | None": + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + sql = f""" + UPDATE {self._session_table} + SET update_time = CURRENT_TIMESTAMP + WHERE id = %s + RETURNING id, app_name, user_id, state, create_time, update_time + """ + else: + sql = f""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {self._session_table} + WHERE id = %s + """ try: with self._config.provide_connection() as conn, conn.cursor() as cur: @@ -511,9 +529,11 @@ def _get_session(self, session_id: str) -> "SessionRecord | None": except errors.UndefinedTable: return None - async def get_session(self, session_id: str) -> "SessionRecord | None": + async def get_session( + self, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": """Get session by ID.""" - return await async_(self._get_session)(session_id) + return await async_(self._get_session)(session_id, renew_for) def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: sql = f""" diff --git a/sqlspec/adapters/duckdb/adk/store.py b/sqlspec/adapters/duckdb/adk/store.py index 051316b02..a216f9312 100644 --- a/sqlspec/adapters/duckdb/adk/store.py +++ b/sqlspec/adapters/duckdb/adk/store.py @@ -13,7 +13,7 @@ """ import contextlib -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any, Final, cast from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord @@ -256,7 +256,7 @@ async def create_session( """ return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) - def _get_session(self, session_id: str) -> "SessionRecord | None": + def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = None) -> "SessionRecord | None": """Synchronous implementation of get_session.""" sql = f""" SELECT id, app_name, user_id, state, create_time, update_time @@ -266,6 +266,11 @@ def _get_session(self, session_id: str) -> "SessionRecord | None": try: with self._config.provide_connection() as conn: + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + update_sql = f"UPDATE {self._session_table} SET update_time = ? WHERE id = ?" + conn.execute(update_sql, (datetime.now(timezone.utc), session_id)) + conn.commit() + cursor = conn.execute(sql, (session_id,)) row = cursor.fetchone() @@ -289,11 +294,14 @@ def _get_session(self, session_id: str) -> "SessionRecord | None": return None raise - async def get_session(self, session_id: str) -> "SessionRecord | None": + async def get_session( + self, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": """Get session by ID. Args: session_id: Session identifier. + renew_for: If positive, touch update_time while reading. Returns: Session record or None if not found. @@ -302,7 +310,7 @@ async def get_session(self, session_id: str) -> "SessionRecord | None": DuckDB returns datetime objects for TIMESTAMPTZ columns. JSON is parsed from database storage. """ - return await async_(self._get_session)(session_id) + return await async_(self._get_session)(session_id, renew_for) def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: """Synchronous implementation of update_session_state.""" diff --git a/sqlspec/adapters/mysqlconnector/adk/store.py b/sqlspec/adapters/mysqlconnector/adk/store.py index 652e358cb..b92a6da93 100644 --- a/sqlspec/adapters/mysqlconnector/adk/store.py +++ b/sqlspec/adapters/mysqlconnector/adk/store.py @@ -11,7 +11,7 @@ from sqlspec.utils.sync_tools import async_, run_ if TYPE_CHECKING: - from datetime import datetime + from datetime import datetime, timedelta from sqlspec.adapters.mysqlconnector.config import MysqlConnectorAsyncConfig, MysqlConnectorSyncConfig from sqlspec.extensions.adk import MemoryRecord @@ -146,7 +146,9 @@ async def create_session( return await self.get_session(session_id) # type: ignore[return-value] - async def get_session(self, session_id: str) -> "SessionRecord | None": + async def get_session( + self, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} @@ -157,6 +159,11 @@ async def get_session(self, session_id: str) -> "SessionRecord | None": async with self._config.provide_connection() as conn: cursor = await conn.cursor() try: + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + update_sql = f"UPDATE {self._session_table} SET update_time = UTC_TIMESTAMP(6) WHERE id = %s" + await cursor.execute(update_sql, (session_id,)) + await conn.commit() + await cursor.execute(sql, (session_id,)) row = await cursor.fetchone() finally: @@ -523,7 +530,7 @@ async def create_session( """Create a new session.""" return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) - def _get_session(self, session_id: str) -> "SessionRecord | None": + def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = None) -> "SessionRecord | None": sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} @@ -534,6 +541,11 @@ def _get_session(self, session_id: str) -> "SessionRecord | None": with self._config.provide_connection() as conn: cursor = conn.cursor() try: + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + update_sql = f"UPDATE {self._session_table} SET update_time = UTC_TIMESTAMP(6) WHERE id = %s" + cursor.execute(update_sql, (session_id,)) + conn.commit() + cursor.execute(sql, (session_id,)) row = cursor.fetchone() finally: @@ -557,9 +569,11 @@ def _get_session(self, session_id: str) -> "SessionRecord | None": return None raise - async def get_session(self, session_id: str) -> "SessionRecord | None": + async def get_session( + self, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": """Get session by ID.""" - return await async_(self._get_session)(session_id) + return await async_(self._get_session)(session_id, renew_for) def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: state_json = to_json(state) diff --git a/sqlspec/adapters/oracledb/adk/store.py b/sqlspec/adapters/oracledb/adk/store.py index ae002a976..24ccf5c91 100644 --- a/sqlspec/adapters/oracledb/adk/store.py +++ b/sqlspec/adapters/oracledb/adk/store.py @@ -20,7 +20,7 @@ from sqlspec.utils.type_guards import is_async_readable, is_readable if TYPE_CHECKING: - from datetime import datetime + from datetime import datetime, timedelta from sqlspec.adapters.oracledb.config import OracleAsyncConfig, OracleSyncConfig from sqlspec.extensions.adk import MemoryRecord @@ -498,11 +498,14 @@ async def create_session( return await self.get_session(session_id) # type: ignore[return-value] - async def get_session(self, session_id: str) -> "SessionRecord | None": + async def get_session( + self, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": """Get session by ID. Args: session_id: Session identifier. + renew_for: If positive, touch update_time while reading. Returns: Session record or None if not found. @@ -515,6 +518,13 @@ async def get_session(self, session_id: str) -> "SessionRecord | None": try: async with self._config.provide_connection() as conn: cursor = conn.cursor() + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + await cursor.execute( + f"UPDATE {self._session_table} SET update_time = SYSTIMESTAMP WHERE id = :id", + {"id": session_id}, + ) + await conn.commit() + await cursor.execute( f""" SELECT id, app_name, user_id, state, create_time, update_time @@ -1264,11 +1274,12 @@ async def create_session( """Create a new session.""" return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) - def _get_session(self, session_id: str) -> "SessionRecord | None": + def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = None) -> "SessionRecord | None": """Get session by ID. Args: session_id: Session identifier. + renew_for: If positive, touch update_time while reading. Returns: Session record or None if not found. @@ -1287,6 +1298,13 @@ def _get_session(self, session_id: str) -> "SessionRecord | None": try: with self._config.provide_connection() as conn: cursor = conn.cursor() + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + cursor.execute( + f"UPDATE {self._session_table} SET update_time = SYSTIMESTAMP WHERE id = :id", + {"id": session_id}, + ) + conn.commit() + cursor.execute(sql, {"id": session_id}) row = cursor.fetchone() @@ -1311,9 +1329,11 @@ def _get_session(self, session_id: str) -> "SessionRecord | None": return None raise - async def get_session(self, session_id: str) -> "SessionRecord | None": + async def get_session( + self, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": """Get session by ID.""" - return await async_(self._get_session)(session_id) + return await async_(self._get_session)(session_id, renew_for) def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: """Update session state. diff --git a/sqlspec/adapters/psqlpy/adk/store.py b/sqlspec/adapters/psqlpy/adk/store.py index 88d520143..8cc83fcaa 100644 --- a/sqlspec/adapters/psqlpy/adk/store.py +++ b/sqlspec/adapters/psqlpy/adk/store.py @@ -11,7 +11,7 @@ from sqlspec.utils.type_guards import has_query_result_metadata if TYPE_CHECKING: - from datetime import datetime + from datetime import datetime, timedelta from sqlspec.adapters.psqlpy.config import PsqlpyConfig from sqlspec.extensions.adk import MemoryRecord @@ -119,8 +119,18 @@ async def create_session( return await self.get_session(session_id) # type: ignore[return-value] - async def get_session(self, session_id: str) -> "SessionRecord | None": - sql = f""" + async def get_session( + self, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + sql = f""" + UPDATE {self._session_table} + SET update_time = CURRENT_TIMESTAMP + WHERE id = $1 + RETURNING id, app_name, user_id, state, create_time, update_time + """ + else: + sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} WHERE id = $1 diff --git a/sqlspec/adapters/psycopg/adk/store.py b/sqlspec/adapters/psycopg/adk/store.py index b6c4b4da1..0b4025d74 100644 --- a/sqlspec/adapters/psycopg/adk/store.py +++ b/sqlspec/adapters/psycopg/adk/store.py @@ -12,7 +12,7 @@ from sqlspec.utils.sync_tools import async_, run_ if TYPE_CHECKING: - from datetime import datetime + from datetime import datetime, timedelta from sqlspec.adapters.psycopg.config import PsycopgAsyncConfig, PsycopgSyncConfig from sqlspec.extensions.adk import MemoryRecord @@ -155,12 +155,22 @@ async def create_session( return await self.get_session(session_id) # type: ignore[return-value] - async def get_session(self, session_id: str) -> "SessionRecord | None": - query = pg_sql.SQL(""" - SELECT id, app_name, user_id, state, create_time, update_time - FROM {table} - WHERE id = %s - """).format(table=pg_sql.Identifier(self._session_table)) + async def get_session( + self, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + query = pg_sql.SQL(""" + UPDATE {table} + SET update_time = CURRENT_TIMESTAMP + WHERE id = %s + RETURNING id, app_name, user_id, state, create_time, update_time + """).format(table=pg_sql.Identifier(self._session_table)) + else: + query = pg_sql.SQL(""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {table} + WHERE id = %s + """).format(table=pg_sql.Identifier(self._session_table)) try: async with self._config.provide_connection() as conn, conn.cursor() as cur: @@ -488,12 +498,20 @@ async def create_session( """Create a new session.""" return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) - def _get_session(self, session_id: str) -> "SessionRecord | None": - query = pg_sql.SQL(""" - SELECT id, app_name, user_id, state, create_time, update_time - FROM {table} - WHERE id = %s - """).format(table=pg_sql.Identifier(self._session_table)) + def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = None) -> "SessionRecord | None": + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + query = pg_sql.SQL(""" + UPDATE {table} + SET update_time = CURRENT_TIMESTAMP + WHERE id = %s + RETURNING id, app_name, user_id, state, create_time, update_time + """).format(table=pg_sql.Identifier(self._session_table)) + else: + query = pg_sql.SQL(""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {table} + WHERE id = %s + """).format(table=pg_sql.Identifier(self._session_table)) try: with self._config.provide_connection() as conn, conn.cursor() as cur: @@ -514,9 +532,11 @@ def _get_session(self, session_id: str) -> "SessionRecord | None": except errors.UndefinedTable: return None - async def get_session(self, session_id: str) -> "SessionRecord | None": + async def get_session( + self, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": """Get session by ID.""" - return await async_(self._get_session)(session_id) + return await async_(self._get_session)(session_id, renew_for) def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: query = pg_sql.SQL(""" diff --git a/sqlspec/adapters/pymysql/adk/store.py b/sqlspec/adapters/pymysql/adk/store.py index 8f89e8442..e7fbb0340 100644 --- a/sqlspec/adapters/pymysql/adk/store.py +++ b/sqlspec/adapters/pymysql/adk/store.py @@ -11,7 +11,7 @@ from sqlspec.utils.sync_tools import async_, run_ if TYPE_CHECKING: - from datetime import datetime + from datetime import datetime, timedelta from sqlspec.adapters.pymysql.config import PyMysqlConfig from sqlspec.extensions.adk import MemoryRecord @@ -152,7 +152,7 @@ async def create_session( """Create a new session.""" return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) - def _get_session(self, session_id: str) -> "SessionRecord | None": + def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = None) -> "SessionRecord | None": sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} @@ -163,6 +163,11 @@ def _get_session(self, session_id: str) -> "SessionRecord | None": with self._config.provide_connection() as conn: cursor = conn.cursor() try: + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + update_sql = f"UPDATE {self._session_table} SET update_time = UTC_TIMESTAMP(6) WHERE id = %s" + cursor.execute(update_sql, (session_id,)) + conn.commit() + cursor.execute(sql, (session_id,)) row = cursor.fetchone() finally: @@ -186,9 +191,11 @@ def _get_session(self, session_id: str) -> "SessionRecord | None": return None raise - async def get_session(self, session_id: str) -> "SessionRecord | None": + async def get_session( + self, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": """Get session by ID.""" - return await async_(self._get_session)(session_id) + return await async_(self._get_session)(session_id, renew_for) def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: state_json = to_json(state) diff --git a/sqlspec/adapters/spanner/adk/store.py b/sqlspec/adapters/spanner/adk/store.py index 0f7861260..50a3233e3 100644 --- a/sqlspec/adapters/spanner/adk/store.py +++ b/sqlspec/adapters/spanner/adk/store.py @@ -123,7 +123,17 @@ async def create_session( """Create a new session.""" return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) - def _get_session(self, session_id: str) -> "SessionRecord | None": + def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = None) -> "SessionRecord | None": + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + update_sql = f""" + UPDATE {self._session_table} + SET update_time = PENDING_COMMIT_TIMESTAMP() + WHERE id = @id + """ + if self._shard_count > 1: + update_sql = f"{update_sql} AND shard_id = MOD(FARM_FINGERPRINT(@id), {self._shard_count})" + self._run_write([(update_sql, {"id": session_id}, {"id": SPANNER_PARAM_TYPES.STRING})]) + sql = f""" SELECT id, app_name, user_id, state, create_time, update_time{", " + self._owner_id_column_name if self._owner_id_column_name else ""} FROM {self._session_table} @@ -149,9 +159,11 @@ def _get_session(self, session_id: str) -> "SessionRecord | None": } return record - async def get_session(self, session_id: str) -> "SessionRecord | None": + async def get_session( + self, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": """Get session by ID.""" - return await async_(self._get_session)(session_id) + return await async_(self._get_session)(session_id, renew_for) def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: params = {"id": session_id, "state": to_json(state)} diff --git a/sqlspec/adapters/sqlite/adk/store.py b/sqlspec/adapters/sqlite/adk/store.py index 31de7756b..6104e7f66 100644 --- a/sqlspec/adapters/sqlite/adk/store.py +++ b/sqlspec/adapters/sqlite/adk/store.py @@ -1,7 +1,7 @@ """SQLite sync ADK store for Google Agent Development Kit session/event storage.""" import sqlite3 -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any, Final from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord @@ -271,7 +271,7 @@ async def create_session( """ return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) - def _get_session(self, session_id: str) -> "SessionRecord | None": + def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = None) -> "SessionRecord | None": """Synchronous implementation of get_session.""" sql = f""" SELECT id, app_name, user_id, state, create_time, update_time @@ -282,6 +282,11 @@ def _get_session(self, session_id: str) -> "SessionRecord | None": try: with self._config.provide_connection() as conn: self._apply_pragmas(conn) + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + update_sql = f"UPDATE {self._session_table} SET update_time = ? WHERE id = ?" + conn.execute(update_sql, (_datetime_to_julian(datetime.now(timezone.utc)), session_id)) + conn.commit() + cursor = conn.execute(sql, (session_id,)) row = cursor.fetchone() @@ -301,11 +306,14 @@ def _get_session(self, session_id: str) -> "SessionRecord | None": return None raise - async def get_session(self, session_id: str) -> "SessionRecord | None": + async def get_session( + self, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": """Get session by ID. Args: session_id: Session identifier. + renew_for: If positive, touch update_time while reading. Returns: Session record or None if not found. @@ -314,7 +322,7 @@ async def get_session(self, session_id: str) -> "SessionRecord | None": SQLite returns Julian Day (REAL) for timestamps. JSON is parsed from TEXT storage. """ - return await async_(self._get_session)(session_id) + return await async_(self._get_session)(session_id, renew_for) def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: """Synchronous implementation of update_session_state.""" diff --git a/sqlspec/extensions/adk/service.py b/sqlspec/extensions/adk/service.py index 9f7d2a36e..04205c484 100644 --- a/sqlspec/extensions/adk/service.py +++ b/sqlspec/extensions/adk/service.py @@ -2,7 +2,7 @@ import logging import uuid -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any from google.adk.sessions.base_session_service import BaseSessionService, GetSessionConfig, ListSessionsResponse @@ -97,7 +97,13 @@ async def create_session( return record_to_session(record, events=[]) async def get_session( - self, *, app_name: str, user_id: str, session_id: str, config: "GetSessionConfig | None" = None + self, + *, + app_name: str, + user_id: str, + session_id: str, + config: "GetSessionConfig | None" = None, + renew_for: int | timedelta | None = None, ) -> "Session | None": """Get a session by ID. @@ -106,11 +112,12 @@ async def get_session( user_id: ID of the user. session_id: Session identifier. config: Configuration for retrieving events. + renew_for: If positive, touch the session update timestamp while reading. Returns: Session object if found, None otherwise. """ - record = await self._store.get_session(session_id) + record = await self._store.get_session(session_id, renew_for=renew_for) if not record: log_with_context( diff --git a/sqlspec/extensions/adk/store.py b/sqlspec/extensions/adk/store.py index 34f6046f6..971bcd90c 100644 --- a/sqlspec/extensions/adk/store.py +++ b/sqlspec/extensions/adk/store.py @@ -187,11 +187,14 @@ async def create_session( raise NotImplementedError @abstractmethod - async def get_session(self, session_id: str) -> "SessionRecord | None": + async def get_session( + self, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": """Get a session by ID. Args: session_id: Session identifier. + renew_for: If positive, touch the session update timestamp while reading. Returns: Session record if found, None otherwise. diff --git a/tests/integration/adapters/_adk_contract_helpers.py b/tests/integration/adapters/_adk_contract_helpers.py index e46c44427..0ece74503 100644 --- a/tests/integration/adapters/_adk_contract_helpers.py +++ b/tests/integration/adapters/_adk_contract_helpers.py @@ -1,5 +1,6 @@ """Shared acceptance helpers for ADK adapter integration tests.""" +import asyncio from datetime import datetime, timedelta, timezone from typing import Protocol from uuid import uuid4 @@ -10,6 +11,7 @@ "assert_memory_store_contract", "assert_session_event_cleanup_contract", "assert_session_event_store_contract", + "assert_session_get_session_renewal_contract", ) @@ -20,7 +22,9 @@ async def create_session( self, session_id: str, app_name: str, user_id: str, state: dict[str, object], owner_id: object | None = None ) -> SessionRecord: ... - async def get_session(self, session_id: str) -> SessionRecord | None: ... + async def get_session( + self, session_id: str, *, renew_for: int | timedelta | None = None + ) -> SessionRecord | None: ... async def update_session_state(self, session_id: str, state: dict[str, object]) -> None: ... @@ -113,6 +117,14 @@ def _event_data(record: EventRecord) -> dict[str, object]: return value +def _as_utc(value: datetime | str) -> datetime: + if isinstance(value, str): + value = datetime.fromisoformat(value) + if value.tzinfo is None: + return value.replace(tzinfo=timezone.utc) + return value.astimezone(timezone.utc) + + async def assert_session_event_store_contract(store: SessionEventStore, *, marker: str) -> None: """Assert the shared ADK session/event store acceptance contract. @@ -201,6 +213,27 @@ async def assert_session_event_store_contract(store: SessionEventStore, *, marke assert await store.get_events(session_id) == [] +async def assert_session_get_session_renewal_contract(store: SessionEventStore, *, marker: str) -> None: + """Assert get_session can renew/touch update_time while reading.""" + app_name = _contract_key(marker, "renew-app") + user_id = _contract_key(marker, "renew-user") + session_id = _contract_key(marker, "renew-session") + + created = await store.create_session(session_id, app_name, user_id, {"renew": True}) + original_update_time = _as_utc(created["update_time"]) + await asyncio.sleep(0.02) + + before_renewal = datetime.now(timezone.utc) - timedelta(seconds=2) + renewed = await store.get_session(session_id, renew_for=timedelta(hours=1)) + after_renewal = datetime.now(timezone.utc) + timedelta(seconds=2) + + assert renewed is not None + renewed_update_time = _as_utc(renewed["update_time"]) + assert renewed_update_time > original_update_time + assert before_renewal <= renewed_update_time <= after_renewal + assert renewed["state"] == {"renew": True} + + async def assert_session_event_cleanup_contract(store: SessionEventStore, *, marker: str) -> None: """Assert ADK session/event cleanup hooks remove only matching rows.""" app_name = _contract_key(marker, "cleanup-app") diff --git a/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py b/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py index 0719cb3ef..96aee2f40 100644 --- a/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py +++ b/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py @@ -7,7 +7,10 @@ from sqlspec.adapters.adbc import AdbcConfig from sqlspec.adapters.adbc.adk import AdbcADKStore -from tests.integration.adapters._adk_contract_helpers import assert_session_event_cleanup_contract +from tests.integration.adapters._adk_contract_helpers import ( + assert_session_event_cleanup_contract, + assert_session_get_session_renewal_contract, +) pytestmark = [pytest.mark.xdist_group("sqlite"), pytest.mark.adbc, pytest.mark.integration] @@ -97,6 +100,11 @@ async def test_session_event_cleanup_contract(adbc_store: Any) -> None: await assert_session_event_cleanup_contract(adbc_store, marker="adbc") +async def test_session_get_session_renewal_contract(adbc_store: Any) -> None: + """ADBC can touch session update_time while reading a session.""" + await assert_session_get_session_renewal_contract(adbc_store, marker="adbc") + + async def test_list_sessions(adbc_store: Any) -> None: """Test listing sessions for an app and user.""" app_name = "test-app" diff --git a/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py b/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py index bad01ff0d..402c54c23 100644 --- a/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py +++ b/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py @@ -11,6 +11,7 @@ from tests.integration.adapters._adk_contract_helpers import ( assert_session_event_cleanup_contract, assert_session_event_store_contract, + assert_session_get_session_renewal_contract, ) pytestmark = pytest.mark.xdist_group("sqlite") @@ -76,6 +77,15 @@ async def test_aiosqlite_session_event_cleanup_contract(tmp_path: Path) -> None: await config.close_pool() +async def test_aiosqlite_session_get_session_renewal_contract(tmp_path: Path) -> None: + """AioSQLite can touch session update_time while reading a session.""" + config, store = await _build_store(tmp_path) + try: + await assert_session_get_session_renewal_contract(store, marker="aiosqlite") + finally: + await config.close_pool() + + async def test_aiosqlite_append_event_and_update_state_is_atomic_contract(tmp_path: Path) -> None: """Event append and durable state update happen through the clean-break method.""" config, store = await _build_store(tmp_path) diff --git a/tests/integration/adapters/duckdb/extensions/adk/test_store.py b/tests/integration/adapters/duckdb/extensions/adk/test_store.py index e62aca06d..960df345f 100644 --- a/tests/integration/adapters/duckdb/extensions/adk/test_store.py +++ b/tests/integration/adapters/duckdb/extensions/adk/test_store.py @@ -13,6 +13,7 @@ from tests.integration.adapters._adk_contract_helpers import ( assert_session_event_cleanup_contract, assert_session_event_store_contract, + assert_session_get_session_renewal_contract, ) pytestmark = [pytest.mark.duckdb, pytest.mark.integration] @@ -61,6 +62,11 @@ async def test_duckdb_session_event_cleanup_contract(duckdb_adk_store: DuckdbADK await assert_session_event_cleanup_contract(duckdb_adk_store, marker="duckdb") +async def test_duckdb_session_get_session_renewal_contract(duckdb_adk_store: DuckdbADKStore) -> None: + """DuckDB can touch session update_time while reading a session.""" + await assert_session_get_session_renewal_contract(duckdb_adk_store, marker="duckdb") + + async def test_create_and_get_session(duckdb_adk_store: DuckdbADKStore) -> None: """Test creating and retrieving a session.""" session_id = "session-001" diff --git a/tests/integration/adapters/sqlite/extensions/adk/test_store.py b/tests/integration/adapters/sqlite/extensions/adk/test_store.py index 8f4380503..b0a3b9dd1 100644 --- a/tests/integration/adapters/sqlite/extensions/adk/test_store.py +++ b/tests/integration/adapters/sqlite/extensions/adk/test_store.py @@ -11,6 +11,7 @@ from tests.integration.adapters._adk_contract_helpers import ( assert_session_event_cleanup_contract, assert_session_event_store_contract, + assert_session_get_session_renewal_contract, ) pytestmark = pytest.mark.xdist_group("sqlite") @@ -56,6 +57,15 @@ async def test_sqlite_session_event_cleanup_contract(tmp_path: Path) -> None: config.close_pool() +async def test_sqlite_session_get_session_renewal_contract(tmp_path: Path) -> None: + """SQLite can touch session update_time while reading a session.""" + config, store = await _build_store(tmp_path) + try: + await assert_session_get_session_renewal_contract(store, marker="sqlite") + finally: + config.close_pool() + + async def test_sqlite_append_event_and_update_state_is_atomic_contract(tmp_path: Path) -> None: """Event append and durable state update happen through the clean-break method.""" config, store = await _build_store(tmp_path) diff --git a/tests/unit/extensions/test_adk/test_service.py b/tests/unit/extensions/test_adk/test_service.py index 8a6bab71f..fb3ada2bb 100644 --- a/tests/unit/extensions/test_adk/test_service.py +++ b/tests/unit/extensions/test_adk/test_service.py @@ -10,7 +10,7 @@ """ import importlib.util -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from typing import Any import pytest @@ -41,6 +41,7 @@ def __init__(self) -> None: self.append_event_and_update_state_calls: list[dict[str, Any]] = [] self.append_event_and_update_state_called = False self.get_session_calls = 0 + self.get_session_call_args: list[dict[str, Any]] = [] # Track calls to create_session self.create_session_calls: list[dict[str, Any]] = [] @@ -71,8 +72,11 @@ async def append_event_and_update_state( self._session_record = updated return updated - async def get_session(self, session_id: str) -> "dict[str, Any] | None": + async def get_session( + self, session_id: str, *, renew_for: int | timedelta | None = None + ) -> "dict[str, Any] | None": self.get_session_calls += 1 + self.get_session_call_args.append({"session_id": session_id, "renew_for": renew_for}) return self._session_record async def create_session( @@ -137,6 +141,19 @@ def _make_event( ) +@pytest.mark.anyio +async def test_get_session_forwards_renew_for_to_store() -> None: + """Session service exposes store-level read renewal for HTTP keep-alive callers.""" + store = MockStore() + service = SQLSpecSessionService(store) # type: ignore[arg-type] + renew_for = timedelta(hours=1) + + session = await service.get_session(app_name="app", user_id="u1", session_id="s1", renew_for=renew_for) + + assert session is not None + assert store.get_session_call_args[0] == {"session_id": "s1", "renew_for": renew_for} + + # --------------------------------------------------------------------------- # append_event — calls append_event_and_update_state # --------------------------------------------------------------------------- @@ -367,7 +384,9 @@ def __init__(self, *, stale_marker: bool = False, stale_timestamp: bool = False) self._stale_marker = stale_marker self._stale_timestamp = stale_timestamp - async def get_session(self, session_id: str) -> "dict[str, Any] | None": + async def get_session( + self, session_id: str, *, renew_for: int | timedelta | None = None + ) -> "dict[str, Any] | None": record = dict(self._session_record) if self._stale_marker or self._stale_timestamp: # Simulate a storage-side update by advancing update_time @@ -380,7 +399,9 @@ async def get_session(self, session_id: str) -> "dict[str, Any] | None": class MissingSessionStore(MockStore): """Mock store where the session disappears between load and append.""" - async def get_session(self, session_id: str) -> "dict[str, Any] | None": + async def get_session( + self, session_id: str, *, renew_for: int | timedelta | None = None + ) -> "dict[str, Any] | None": return None diff --git a/tests/unit/extensions/test_adk/test_store_config.py b/tests/unit/extensions/test_adk/test_store_config.py index 42f054860..14770f8a7 100644 --- a/tests/unit/extensions/test_adk/test_store_config.py +++ b/tests/unit/extensions/test_adk/test_store_config.py @@ -1,6 +1,7 @@ # pyright: reportPrivateUsage=false """Tests for shared ADK store configuration behavior.""" +import inspect import logging from datetime import datetime, timedelta, timezone from typing import Any @@ -44,7 +45,7 @@ async def create_session( update_time=datetime.now(), ) - async def get_session(self, session_id: str) -> SessionRecord | None: + async def get_session(self, session_id: str, *, renew_for: int | timedelta | None = None) -> SessionRecord | None: return None async def update_session_state(self, session_id: str, state: dict[str, Any]) -> None: @@ -180,6 +181,14 @@ def test_session_store_contract_declares_cleanup_hooks() -> None: assert "delete_idle_sessions" in BaseAsyncADKStore.__abstractmethods__ +def test_session_store_contract_get_session_accepts_renew_for_kwarg() -> None: + signature = inspect.signature(BaseAsyncADKStore.get_session) + + parameter = signature.parameters["renew_for"] + assert parameter.kind is inspect.Parameter.KEYWORD_ONLY + assert parameter.default is None + + @pytest.mark.parametrize("expires_in", [None, 0, timedelta(seconds=-5)]) def test_async_session_store_calculate_expires_at_returns_none_for_non_positive_values( expires_in: int | timedelta | None, From 884d5c77b91d2a395b0668c9472e2a32788d3758 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 23 May 2026 22:25:29 +0000 Subject: [PATCH 14/29] feat(adk): add table lifecycle controls --- sqlspec/extensions/adk/store.py | 61 ++++++++++++++++++- .../adapters/_adk_contract_helpers.py | 25 ++++++++ .../extensions/adk/test_session_operations.py | 6 ++ .../aiosqlite/extensions/adk/test_store.py | 10 +++ .../duckdb/extensions/adk/test_store.py | 6 ++ .../sqlite/extensions/adk/test_store.py | 10 +++ .../extensions/test_adk/test_store_config.py | 7 +++ 7 files changed, 124 insertions(+), 1 deletion(-) diff --git a/sqlspec/extensions/adk/store.py b/sqlspec/extensions/adk/store.py index 971bcd90c..dbfaef677 100644 --- a/sqlspec/extensions/adk/store.py +++ b/sqlspec/extensions/adk/store.py @@ -1,15 +1,17 @@ """Base store class for ADK session backends.""" +import inspect import logging import re from abc import ABC, abstractmethod from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Any, Final, Generic, TypeVar +from typing import TYPE_CHECKING, Any, Final, Generic, TypeVar, cast from sqlspec.extensions.adk._config_utils import _get_adk_session_store_config from sqlspec.observability import resolve_db_system from sqlspec.utils.identifiers import validate_identifier from sqlspec.utils.logging import get_logger, log_with_context +from sqlspec.utils.sync_tools import async_ if TYPE_CHECKING: from sqlspec.config import DatabaseConfigProtocol @@ -319,6 +321,43 @@ async def ensure_tables(self) -> None: await self.create_tables() self._log_tables_created() + async def drop_tables(self) -> None: + """Drop all ADK tables managed by this store in FK-safe order.""" + await self._execute_lifecycle_scripts(self._get_drop_tables_sql()) + self._log_tables_dropped() + + async def recreate_tables(self) -> None: + """Drop and recreate all ADK tables managed by this store.""" + await self.drop_tables() + await self.ensure_tables() + self._log_tables_recreated() + + async def _execute_lifecycle_scripts(self, statements: list[str]) -> None: + """Execute lifecycle DDL scripts for async and sync-backed configs.""" + session_context = self._config.provide_session() + if hasattr(session_context, "__aenter__"): + async with cast("Any", session_context) as driver: + for statement in statements: + result = driver.execute_script(statement) + if inspect.isawaitable(result): + await result + commit = getattr(driver, "commit", None) + if callable(commit): + result = commit() + if inspect.isawaitable(result): + await result + return + + def _execute_sync() -> None: + with cast("Any", self._config.provide_session()) as driver: + for statement in statements: + driver.execute_script(statement) + commit = getattr(driver, "commit", None) + if callable(commit): + commit() + + await async_(_execute_sync)() + @abstractmethod async def _get_create_sessions_table_sql(self) -> str: """Get the CREATE TABLE SQL for the sessions table. @@ -360,3 +399,23 @@ def _log_tables_created(self) -> None: session_table=self._session_table, events_table=self._events_table, ) + + def _log_tables_dropped(self) -> None: + log_with_context( + logger, + logging.DEBUG, + "adk.tables.dropped", + db_system=resolve_db_system(type(self).__name__), + session_table=self._session_table, + events_table=self._events_table, + ) + + def _log_tables_recreated(self) -> None: + log_with_context( + logger, + logging.DEBUG, + "adk.tables.recreated", + db_system=resolve_db_system(type(self).__name__), + session_table=self._session_table, + events_table=self._events_table, + ) diff --git a/tests/integration/adapters/_adk_contract_helpers.py b/tests/integration/adapters/_adk_contract_helpers.py index 0ece74503..685a80e97 100644 --- a/tests/integration/adapters/_adk_contract_helpers.py +++ b/tests/integration/adapters/_adk_contract_helpers.py @@ -12,6 +12,7 @@ "assert_session_event_cleanup_contract", "assert_session_event_store_contract", "assert_session_get_session_renewal_contract", + "assert_session_table_lifecycle_contract", ) @@ -46,6 +47,10 @@ async def delete_expired_events(self, before: datetime) -> int: ... async def delete_idle_sessions(self, updated_before: datetime) -> int: ... + async def drop_tables(self) -> None: ... + + async def recreate_tables(self) -> None: ... + class MemoryStore(Protocol): """Minimal ADK memory store surface used by contract tests.""" @@ -234,6 +239,26 @@ async def assert_session_get_session_renewal_contract(store: SessionEventStore, assert renewed["state"] == {"renew": True} +async def assert_session_table_lifecycle_contract(store: SessionEventStore, *, marker: str) -> None: + """Assert ADK stores can drop and recreate their managed session tables.""" + app_name = _contract_key(marker, "lifecycle-app") + user_id = _contract_key(marker, "lifecycle-user") + session_id = _contract_key(marker, "lifecycle-session") + + await store.create_session(session_id, app_name, user_id, {"phase": "before"}) + assert await store.get_session(session_id) is not None + + await store.recreate_tables() + + assert await store.get_session(session_id) is None + recreated = await store.create_session(session_id, app_name, user_id, {"phase": "after"}) + assert recreated["state"] == {"phase": "after"} + + await store.drop_tables() + assert await store.get_session(session_id) is None + await store.drop_tables() + + async def assert_session_event_cleanup_contract(store: SessionEventStore, *, marker: str) -> None: """Assert ADK session/event cleanup hooks remove only matching rows.""" app_name = _contract_key(marker, "cleanup-app") diff --git a/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py b/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py index 96aee2f40..d06c20425 100644 --- a/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py +++ b/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py @@ -10,6 +10,7 @@ from tests.integration.adapters._adk_contract_helpers import ( assert_session_event_cleanup_contract, assert_session_get_session_renewal_contract, + assert_session_table_lifecycle_contract, ) pytestmark = [pytest.mark.xdist_group("sqlite"), pytest.mark.adbc, pytest.mark.integration] @@ -105,6 +106,11 @@ async def test_session_get_session_renewal_contract(adbc_store: Any) -> None: await assert_session_get_session_renewal_contract(adbc_store, marker="adbc") +async def test_session_table_lifecycle_contract(adbc_store: Any) -> None: + """ADBC can drop and recreate its ADK session tables programmatically.""" + await assert_session_table_lifecycle_contract(adbc_store, marker="adbc") + + async def test_list_sessions(adbc_store: Any) -> None: """Test listing sessions for an app and user.""" app_name = "test-app" diff --git a/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py b/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py index 402c54c23..7a0179103 100644 --- a/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py +++ b/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py @@ -12,6 +12,7 @@ assert_session_event_cleanup_contract, assert_session_event_store_contract, assert_session_get_session_renewal_contract, + assert_session_table_lifecycle_contract, ) pytestmark = pytest.mark.xdist_group("sqlite") @@ -86,6 +87,15 @@ async def test_aiosqlite_session_get_session_renewal_contract(tmp_path: Path) -> await config.close_pool() +async def test_aiosqlite_session_table_lifecycle_contract(tmp_path: Path) -> None: + """AioSQLite can drop and recreate its ADK session tables programmatically.""" + config, store = await _build_store(tmp_path) + try: + await assert_session_table_lifecycle_contract(store, marker="aiosqlite") + finally: + await config.close_pool() + + async def test_aiosqlite_append_event_and_update_state_is_atomic_contract(tmp_path: Path) -> None: """Event append and durable state update happen through the clean-break method.""" config, store = await _build_store(tmp_path) diff --git a/tests/integration/adapters/duckdb/extensions/adk/test_store.py b/tests/integration/adapters/duckdb/extensions/adk/test_store.py index 960df345f..73088b2a1 100644 --- a/tests/integration/adapters/duckdb/extensions/adk/test_store.py +++ b/tests/integration/adapters/duckdb/extensions/adk/test_store.py @@ -14,6 +14,7 @@ assert_session_event_cleanup_contract, assert_session_event_store_contract, assert_session_get_session_renewal_contract, + assert_session_table_lifecycle_contract, ) pytestmark = [pytest.mark.duckdb, pytest.mark.integration] @@ -67,6 +68,11 @@ async def test_duckdb_session_get_session_renewal_contract(duckdb_adk_store: Duc await assert_session_get_session_renewal_contract(duckdb_adk_store, marker="duckdb") +async def test_duckdb_session_table_lifecycle_contract(duckdb_adk_store: DuckdbADKStore) -> None: + """DuckDB can drop and recreate its ADK session tables programmatically.""" + await assert_session_table_lifecycle_contract(duckdb_adk_store, marker="duckdb") + + async def test_create_and_get_session(duckdb_adk_store: DuckdbADKStore) -> None: """Test creating and retrieving a session.""" session_id = "session-001" diff --git a/tests/integration/adapters/sqlite/extensions/adk/test_store.py b/tests/integration/adapters/sqlite/extensions/adk/test_store.py index b0a3b9dd1..2444c59e0 100644 --- a/tests/integration/adapters/sqlite/extensions/adk/test_store.py +++ b/tests/integration/adapters/sqlite/extensions/adk/test_store.py @@ -12,6 +12,7 @@ assert_session_event_cleanup_contract, assert_session_event_store_contract, assert_session_get_session_renewal_contract, + assert_session_table_lifecycle_contract, ) pytestmark = pytest.mark.xdist_group("sqlite") @@ -66,6 +67,15 @@ async def test_sqlite_session_get_session_renewal_contract(tmp_path: Path) -> No config.close_pool() +async def test_sqlite_session_table_lifecycle_contract(tmp_path: Path) -> None: + """SQLite can drop and recreate its ADK session tables programmatically.""" + config, store = await _build_store(tmp_path) + try: + await assert_session_table_lifecycle_contract(store, marker="sqlite") + finally: + config.close_pool() + + async def test_sqlite_append_event_and_update_state_is_atomic_contract(tmp_path: Path) -> None: """Event append and durable state update happen through the clean-break method.""" config, store = await _build_store(tmp_path) diff --git a/tests/unit/extensions/test_adk/test_store_config.py b/tests/unit/extensions/test_adk/test_store_config.py index 14770f8a7..11c75d457 100644 --- a/tests/unit/extensions/test_adk/test_store_config.py +++ b/tests/unit/extensions/test_adk/test_store_config.py @@ -189,6 +189,13 @@ def test_session_store_contract_get_session_accepts_renew_for_kwarg() -> None: assert parameter.default is None +def test_session_store_contract_exposes_concrete_table_lifecycle_methods() -> None: + assert "drop_tables" not in BaseAsyncADKStore.__abstractmethods__ + assert "recreate_tables" not in BaseAsyncADKStore.__abstractmethods__ + assert inspect.iscoroutinefunction(BaseAsyncADKStore.drop_tables) + assert inspect.iscoroutinefunction(BaseAsyncADKStore.recreate_tables) + + @pytest.mark.parametrize("expires_in", [None, 0, timedelta(seconds=-5)]) def test_async_session_store_calculate_expires_at_returns_none_for_non_positive_values( expires_in: int | timedelta | None, From 9c810507dfe757374fe0416f8e3201bfbb0ff7d2 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 23 May 2026 22:30:34 +0000 Subject: [PATCH 15/29] feat(adk): validate scoped state tables --- sqlspec/extensions/adk/store.py | 32 ++++++++++++++++++- .../extensions/test_adk/test_store_config.py | 22 +++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/sqlspec/extensions/adk/store.py b/sqlspec/extensions/adk/store.py index dbfaef677..37921f659 100644 --- a/sqlspec/extensions/adk/store.py +++ b/sqlspec/extensions/adk/store.py @@ -80,7 +80,16 @@ class BaseAsyncADKStore(ABC, Generic[ConfigT]): - owner_id_column: Optional owner FK column DDL (default: None) """ - __slots__ = ("_config", "_events_table", "_owner_id_column_ddl", "_owner_id_column_name", "_session_table") + __slots__ = ( + "_app_state_table", + "_config", + "_events_table", + "_metadata_table", + "_owner_id_column_ddl", + "_owner_id_column_name", + "_session_table", + "_user_state_table", + ) def __init__(self, config: ConfigT) -> None: """Initialize the ADK store. @@ -98,12 +107,18 @@ def __init__(self, config: ConfigT) -> None: store_config = self._get_store_config_from_extension() self._session_table: str = str(store_config["session_table"]) self._events_table: str = str(store_config["events_table"]) + self._app_state_table: str = str(store_config["app_state_table"]) + self._user_state_table: str = str(store_config["user_state_table"]) + self._metadata_table: str = str(store_config["metadata_table"]) self._owner_id_column_ddl: str | None = store_config.get("owner_id_column") self._owner_id_column_name: str | None = ( _parse_owner_id_column(self._owner_id_column_ddl) if self._owner_id_column_ddl else None ) validate_identifier(self._session_table, label="table name") validate_identifier(self._events_table, label="table name") + validate_identifier(self._app_state_table, label="table name") + validate_identifier(self._user_state_table, label="table name") + validate_identifier(self._metadata_table, label="table name") def _get_store_config_from_extension(self) -> "dict[str, Any]": """Extract ADK store configuration from config.extension_config. @@ -128,6 +143,21 @@ def events_table(self) -> str: """Return the events table name.""" return self._events_table + @property + def app_state_table(self) -> str: + """Return the app-scoped state table name.""" + return self._app_state_table + + @property + def user_state_table(self) -> str: + """Return the user-scoped state table name.""" + return self._user_state_table + + @property + def metadata_table(self) -> str: + """Return the ADK metadata table name.""" + return self._metadata_table + @property def owner_id_column_ddl(self) -> "str | None": """Return the full owner ID column DDL (or None if not configured).""" diff --git a/tests/unit/extensions/test_adk/test_store_config.py b/tests/unit/extensions/test_adk/test_store_config.py index 11c75d457..ea047fb6b 100644 --- a/tests/unit/extensions/test_adk/test_store_config.py +++ b/tests/unit/extensions/test_adk/test_store_config.py @@ -181,6 +181,28 @@ def test_session_store_contract_declares_cleanup_hooks() -> None: assert "delete_idle_sessions" in BaseAsyncADKStore.__abstractmethods__ +def test_session_store_resolves_schema_parity_table_names() -> None: + store = _AsyncSessionStore( + _Config({ + "schema": { + "app_state_table": "agent_app_states", + "user_state_table": "agent_user_states", + "metadata_table": "agent_metadata", + } + }) + ) + + assert store.app_state_table == "agent_app_states" + assert store.user_state_table == "agent_user_states" + assert store.metadata_table == "agent_metadata" + + +@pytest.mark.parametrize("field", ["app_state_table", "user_state_table", "metadata_table"]) +def test_session_store_validates_schema_parity_table_names(field: str) -> None: + with pytest.raises(ValueError, match="Invalid table name"): + _AsyncSessionStore(_Config({"schema": {field: "invalid-name"}})) + + def test_session_store_contract_get_session_accepts_renew_for_kwarg() -> None: signature = inspect.signature(BaseAsyncADKStore.get_session) From c49b7cb059f7ec6a6439c74858418a3478205a35 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 24 May 2026 05:39:42 +0000 Subject: [PATCH 16/29] feat: implement ADK scoped state tables --- .../examples/extensions/adk/backend_config.py | 7 +- docs/examples/extensions/adk/memory_store.py | 2 +- docs/extensions/adk/backends.rst | 7 +- docs/extensions/adk/index.rst | 7 + docs/extensions/adk/migrations.rst | 17 +- docs/extensions/adk/quickstart.rst | 4 +- docs/extensions/adk/schema.rst | 117 +- docs/extensions/adk/scoped_state.rst | 70 + sqlspec/adapters/adbc/adk/store.py | 416 +++- sqlspec/adapters/aiomysql/adk/store.py | 414 ++-- sqlspec/adapters/aiosqlite/adk/store.py | 484 ++-- sqlspec/adapters/asyncmy/adk/store.py | 396 ++-- sqlspec/adapters/asyncpg/adk/store.py | 282 ++- .../adapters/cockroach_asyncpg/adk/store.py | 261 ++- .../adapters/cockroach_psycopg/adk/store.py | 576 +++-- sqlspec/adapters/duckdb/adk/store.py | 505 +++-- sqlspec/adapters/mysqlconnector/adk/store.py | 574 ++++- sqlspec/adapters/oracledb/adk/store.py | 1999 ++++++++++------- sqlspec/adapters/psqlpy/adk/store.py | 354 ++- sqlspec/adapters/psycopg/adk/store.py | 643 ++++-- sqlspec/adapters/pymysql/adk/store.py | 328 ++- sqlspec/adapters/spanner/adk/store.py | 293 ++- sqlspec/adapters/sqlite/adk/store.py | 501 +++-- sqlspec/config.py | 13 +- sqlspec/extensions/adk/_config_utils.py | 8 +- sqlspec/extensions/adk/_versioning.py | 2 +- .../adk/migrations/0001_create_adk_tables.py | 4 + sqlspec/extensions/adk/service.py | 22 +- sqlspec/extensions/adk/store.py | 161 +- .../adapters/_adk_contract_helpers.py | 67 + .../extensions/adk/test_session_operations.py | 6 + .../aiomysql/extensions/adk/conftest.py | 19 +- .../aiomysql/extensions/adk/test_store.py | 6 + .../aiosqlite/extensions/adk/test_store.py | 12 +- .../asyncmy/extensions/adk/conftest.py | 19 +- .../asyncmy/extensions/adk/test_store.py | 6 + .../asyncpg/extensions/adk/conftest.py | 7 +- .../extensions/adk/test_owner_id_column.py | 34 +- .../extensions/adk/test_session_operations.py | 7 + .../duckdb/extensions/adk/test_store.py | 6 + .../mysqlconnector/extensions/adk/conftest.py | 19 +- .../extensions/adk/test_store.py | 8 + .../oracledb/extensions/adk/test_inmemory.py | 14 +- .../spanner/extensions/adk/conftest.py | 2 +- .../spanner/extensions/adk/test_adk_store.py | 6 + .../sqlite/extensions/adk/test_store.py | 10 + .../unit/extensions/test_adk/test_service.py | 65 +- .../extensions/test_adk/test_store_config.py | 111 + .../extensions/test_adk/test_versioning.py | 4 + .../test_events/test_events_config.py | 2 +- tests/unit/utils/test_identifiers.py | 10 +- 51 files changed, 6350 insertions(+), 2557 deletions(-) create mode 100644 docs/extensions/adk/scoped_state.rst diff --git a/docs/examples/extensions/adk/backend_config.py b/docs/examples/extensions/adk/backend_config.py index 13898f6ca..5a5ce6ff4 100644 --- a/docs/examples/extensions/adk/backend_config.py +++ b/docs/examples/extensions/adk/backend_config.py @@ -11,8 +11,11 @@ def test_adk_backend_config() -> None: from sqlspec.adapters.adbc import AdbcConfig adk_config = { - "session_table": "adk_sessions", - "events_table": "adk_events", + "session_table": "adk_session", + "events_table": "adk_event", + "app_state_table": "adk_app_state", + "user_state_table": "adk_user_state", + "metadata_table": "adk_internal_metadata", "memory_table": "adk_memory_entries", "memory_use_fts": True, } diff --git a/docs/examples/extensions/adk/memory_store.py b/docs/examples/extensions/adk/memory_store.py index 08c552471..265695261 100644 --- a/docs/examples/extensions/adk/memory_store.py +++ b/docs/examples/extensions/adk/memory_store.py @@ -29,7 +29,7 @@ async def _run() -> list[dict[str, object]]: "invocation_id": "inv_1", "author": "user", "timestamp": datetime.now(timezone.utc), - "event_json": { + "event_data": { "id": "evt_1", "invocation_id": "inv_1", "author": "user", diff --git a/docs/extensions/adk/backends.rst b/docs/extensions/adk/backends.rst index bc86733ba..577c522a5 100644 --- a/docs/extensions/adk/backends.rst +++ b/docs/extensions/adk/backends.rst @@ -298,8 +298,11 @@ All backends are configured through ``extension_config["adk"]``: connection_config={"dsn": "postgresql://localhost/mydb"}, extension_config={ "adk": { - "session_table": "adk_sessions", - "events_table": "adk_events", + "session_table": "adk_session", + "events_table": "adk_event", + "app_state_table": "adk_app_state", + "user_state_table": "adk_user_state", + "metadata_table": "adk_internal_metadata", "memory_table": "adk_memory_entries", "memory_use_fts": True, "owner_id_column": "tenant_id INTEGER NOT NULL", diff --git a/docs/extensions/adk/index.rst b/docs/extensions/adk/index.rst index 45d8f53b5..8d128409f 100644 --- a/docs/extensions/adk/index.rst +++ b/docs/extensions/adk/index.rst @@ -58,6 +58,12 @@ Choose a guide Table layouts, EventRecord, scoped state, and artifact metadata. + .. grid-item-card:: Scoped State + :link: scoped_state + :link-type: doc + + App, user, session, and runtime-only state behavior. + .. grid-item-card:: API Reference :link: api :link-type: doc @@ -78,5 +84,6 @@ Choose a guide backends adapters schema + scoped_state api migrations diff --git a/docs/extensions/adk/migrations.rst b/docs/extensions/adk/migrations.rst index aa0b522e6..20fae085f 100644 --- a/docs/extensions/adk/migrations.rst +++ b/docs/extensions/adk/migrations.rst @@ -8,7 +8,8 @@ used by your ADK backend, then run them with the SQLSpec migration CLI. Schema Bootstrapping ==================== -You can programmatically create ADK session/event and memory tables with +You can programmatically create ADK session/event/scoped-state/metadata and +memory tables with ``create_tables()`` / ``ensure_tables()``: .. code-block:: python @@ -55,13 +56,19 @@ If you are upgrading from a pre-clean-break version of the ADK extension, note the following schema changes: - **Events table**: The column layout changed to full-event JSON storage. - The ``event_json`` column now stores the entire ADK Event as a JSON blob. + The ``event_data`` column now stores the entire ADK Event as a JSON blob. Individual event columns (``content``, ``actions``, ``branch``, etc.) have been replaced by indexed scalar columns (``invocation_id``, ``author``, - ``timestamp``) plus ``event_json``. + ``timestamp``) plus ``event_data``. +- **Scoped state tables**: New ``adk_app_state`` and ``adk_user_state`` tables + store ``app:`` and ``user:`` scoped keys. Raw ``adk_session.state`` rows now + contain only session-scoped keys; ``SQLSpecSessionService.get_session()`` + returns the merged ADK view. +- **Internal metadata table**: New ``adk_internal_metadata`` table seeded with + ``schema_version = 1``. - **Artifact table**: New table (``adk_artifact_versions``) for artifact metadata. Create this table when enabling the artifact service. -- **BigQuery**: Removed. Migrate to Spanner, PostgreSQL, or any other - supported backend. +- **BigQuery**: Treated as an analytics-replica backend. Use Spanner or a + PostgreSQL-family adapter for latency-sensitive live session state. See :doc:`/usage/migrations` for the full workflow and commands. diff --git a/docs/extensions/adk/quickstart.rst b/docs/extensions/adk/quickstart.rst index 70c1ad321..63ac95359 100644 --- a/docs/extensions/adk/quickstart.rst +++ b/docs/extensions/adk/quickstart.rst @@ -28,8 +28,8 @@ When a user returns, the agent can resume from where it left off. connection_config={"dsn": "postgresql://localhost/mydb"}, extension_config={ "adk": { - "session_table": "adk_sessions", - "events_table": "adk_events", + "session_table": "adk_session", + "events_table": "adk_event", } }, ) diff --git a/docs/extensions/adk/schema.rst b/docs/extensions/adk/schema.rst index 92759ab84..31706dd48 100644 --- a/docs/extensions/adk/schema.rst +++ b/docs/extensions/adk/schema.rst @@ -2,9 +2,10 @@ Schema ====== -ADK stores create tables for sessions, events, and memory entries. The artifact -metadata table contract is documented for deployments that provide a concrete -artifact metadata store. Table names are configurable via +ADK stores create tables for sessions, events, scoped state, metadata, and +memory entries. The artifact metadata table contract is documented for +deployments that provide a concrete artifact metadata store. Table names are +configurable via ``extension_config["adk"]``. You can programmatically create the schema with ``create_tables()`` or @@ -24,7 +25,7 @@ Sessions Table The sessions table stores agent session metadata and durable state. -Default name: ``adk_sessions`` +Default name: ``adk_session`` .. list-table:: :header-rows: 1 @@ -60,13 +61,13 @@ Events Table (EventRecord) ========================== The events table uses **full-event JSON storage**: the entire ADK ``Event`` is -serialized into a single ``event_json`` column alongside a small set of indexed +serialized into a single ``event_data`` column alongside a small set of indexed scalar columns used for query filtering. This design eliminates column drift with upstream ADK releases. New ``Event`` -fields are automatically captured in ``event_json`` without schema changes. +fields are automatically captured in ``event_data`` without schema changes. -Default name: ``adk_events`` +Default name: ``adk_event`` .. list-table:: :header-rows: 1 @@ -86,7 +87,7 @@ Default name: ``adk_events`` * - ``timestamp`` - ``TIMESTAMP`` - Event timestamp (UTC, indexed for range queries). - * - ``event_json`` + * - ``event_data`` - ``JSONB`` / ``JSON`` / ``TEXT`` - Full ADK Event serialized via ``Event.model_dump()``. @@ -113,7 +114,8 @@ Scoped State Semantics ====================== ADK uses key prefixes to scope state visibility across sessions. SQLSpec -respects these prefixes when persisting and loading state. +respects these prefixes at the ``SQLSpecSessionService`` boundary when +persisting and loading state. .. list-table:: :header-rows: 1 @@ -145,10 +147,11 @@ respects these prefixes when persisting and loading state. initial ``INSERT``. 2. On ``append_event()``, the service calls ``filter_temp_state()`` to produce - a durable state snapshot, then calls ``append_event_and_update_state()`` to - atomically persist the event and the state update. + a durable state snapshot, splits prefixed keys into session/app/user + buckets, then calls the matching store hooks. -3. On ``get_session()``, state is loaded from the database. Since ``temp:`` +3. On ``get_session()``, the service loads the session row, app state row, and + user state row, then merges them into the ADK-visible state. Since ``temp:`` keys were never written, they are absent from the loaded state. .. code-block:: python @@ -172,6 +175,82 @@ respects these prefixes when persisting and loading state. # user_state: {"user:preferences": {"theme": "dark"}} # session_state: {"conversation_turn": 5} +App State Table +=============== + +The app state table stores ``app:`` keys shared by all sessions with the same +``app_name``. + +Default name: ``adk_app_state`` + +.. list-table:: + :header-rows: 1 + + * - Column + - Type + - Description + * - ``app_name`` + - ``VARCHAR`` / ``TEXT`` + - Primary key. Application identifier. + * - ``state`` + - ``JSONB`` / ``JSON`` / ``TEXT`` + - App-scoped state mapping. + * - ``update_time`` + - ``TIMESTAMP`` + - Last update time (UTC). + +User State Table +================ + +The user state table stores ``user:`` keys shared by sessions with the same +``app_name`` and ``user_id``. + +Default name: ``adk_user_state`` + +.. list-table:: + :header-rows: 1 + + * - Column + - Type + - Description + * - ``app_name`` + - ``VARCHAR`` / ``TEXT`` + - Application identifier. + * - ``user_id`` + - ``VARCHAR`` / ``TEXT`` + - User identifier. + * - ``state`` + - ``JSONB`` / ``JSON`` / ``TEXT`` + - User-scoped state mapping. + * - ``update_time`` + - ``TIMESTAMP`` + - Last update time (UTC). + +The composite primary key is ``(app_name, user_id)``. + +Internal Metadata Table +======================= + +The metadata table stores ADK schema metadata used by migrations and future +schema-version dispatch. + +Default name: ``adk_internal_metadata`` + +.. list-table:: + :header-rows: 1 + + * - Column + - Type + - Description + * - ``key`` + - ``VARCHAR`` / ``TEXT`` + - Primary key. + * - ``value`` + - ``VARCHAR`` / ``TEXT`` + - Metadata value. + +The initial ADK migration seeds ``schema_version = 1``. + .. _append-event-contract: The ``append_event_and_update_state()`` Contract @@ -181,10 +260,11 @@ This method is the **authoritative durable write boundary** for post-creation session mutations. It atomically: 1. Inserts the event record into the events table. -2. Updates the session's durable state in the sessions table. +2. Updates the session-scoped durable state in the sessions table. Both operations succeed together or fail together within a single database -transaction. +transaction. App/user scoped state writes are routed separately by the session +service through the dedicated scoped-state hooks. .. code-block:: python @@ -192,7 +272,7 @@ transaction. await store.append_event_and_update_state( event_record=event_record, session_id=session.id, - state=durable_state, # temp: keys already stripped + state=session_state, # temp/app/user keys already stripped ) **Why this matters:** @@ -302,8 +382,11 @@ All table names are configurable: connection_config={"dsn": "postgresql://..."}, extension_config={ "adk": { - "session_table": "my_sessions", # default: "adk_sessions" - "events_table": "my_events", # default: "adk_events" + "session_table": "my_session", # default: "adk_session" + "events_table": "my_event", # default: "adk_event" + "app_state_table": "my_app_state", # default: "adk_app_state" + "user_state_table": "my_user_state", # default: "adk_user_state" + "metadata_table": "my_adk_metadata", # default: "adk_internal_metadata" "memory_table": "my_memory", # default: "adk_memory_entries" "artifact_table": "my_artifacts", # artifact metadata stores } diff --git a/docs/extensions/adk/scoped_state.rst b/docs/extensions/adk/scoped_state.rst new file mode 100644 index 000000000..99be15c83 --- /dev/null +++ b/docs/extensions/adk/scoped_state.rst @@ -0,0 +1,70 @@ +============ +Scoped State +============ + +SQLSpec follows the ADK scoped-state prefixes and persists them through the +session service boundary. + +State Prefixes +============== + +.. list-table:: + :header-rows: 1 + + * - Prefix + - Scope + - Persistence + * - ``app:`` + - All sessions for the same ``app_name`` + - Stored in ``adk_app_state`` + * - ``user:`` + - Sessions for the same ``app_name`` and ``user_id`` + - Stored in ``adk_user_state`` + * - ``temp:`` + - Current runtime only + - Never written to storage + * - *(no prefix)* + - One session + - Stored in ``adk_session.state`` + +Write Behavior +============== + +``SQLSpecSessionService.create_session()`` and +``SQLSpecSessionService.append_event()`` strip ``temp:`` keys before +persistence. Durable keys are split into three buckets: + +- ``app:`` keys are written through ``upsert_app_state()``. +- ``user:`` keys are written through ``upsert_user_state()``. +- Unprefixed keys are written to the session row. + +``append_event_and_update_state()`` remains the store-level atomic boundary for +the event row and the session row. The scoped app/user writes are routed by the +service through the dedicated scoped-state store hooks. + +Read Behavior +============= + +``SQLSpecSessionService.get_session()`` reads the session state, app state, and +user state, then returns the merged ADK view. Raw store reads return only the +session row state, so direct database inspection shows ``app:`` and ``user:`` +keys in their dedicated tables rather than in ``adk_session.state``. + +.. code-block:: python + + session = await session_service.create_session( + app_name="agent", + user_id="user_1", + state={ + "app:model": "gemini", + "user:theme": "dark", + "turn": 1, + "temp:scratch": "...", + }, + ) + + assert session.state == { + "app:model": "gemini", + "user:theme": "dark", + "turn": 1, + } diff --git a/sqlspec/adapters/adbc/adk/store.py b/sqlspec/adapters/adbc/adk/store.py index 577331839..c550aa08d 100644 --- a/sqlspec/adapters/adbc/adk/store.py +++ b/sqlspec/adapters/adbc/adk/store.py @@ -87,8 +87,8 @@ def __init__(self, config: "AdbcConfig") -> None: Notes: Configuration is read from config.extension_config["adk"]: - - session_table: Sessions table name (default: "adk_sessions") - - events_table: Events table name (default: "adk_events") + - session_table: Sessions table name (default: "adk_session") + - events_table: Events table name (default: "adk_event") - owner_id_column: Optional owner FK column DDL (default: None) """ super().__init__(config) @@ -99,6 +99,82 @@ def dialect(self) -> str: """Return the detected database dialect.""" return self._dialect + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + + async def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new session.""" + return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) + + async def get_session( + self, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + """Get session by ID.""" + return await async_(self._get_session)(session_id, renew_for) + + async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + """Update session state.""" + await async_(self._update_session_state)(session_id, state) + + async def delete_session(self, session_id: str) -> None: + """Delete session and associated events.""" + await async_(self._delete_session)(session_id) + + async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app.""" + return await async_(self._list_sessions)(app_name, user_id) + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> SessionRecord: + """Atomically append an event and update the session's durable state.""" + return await async_(self._append_event_and_update_state)(event_record, session_id, state) + + async def get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + """Get events for a session.""" + return await async_(self._get_events)(session_id, after_timestamp, limit) + + async def delete_expired_events(self, before: "datetime") -> int: + """Delete events older than the given timestamp.""" + return await async_(self._delete_expired_events)(before) + + async def delete_idle_sessions(self, updated_before: "datetime") -> int: + """Delete sessions whose update_time predates the given threshold.""" + return await async_(self._delete_idle_sessions)(updated_before) + + async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Return app-scoped state for an application.""" + return await async_(self._get_app_state)(app_name) + + async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Return user-scoped state for an application user.""" + return await async_(self._get_user_state)(app_name, user_id) + + async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Insert or replace app-scoped state for an application.""" + await async_(self._upsert_app_state)(app_name, state) + + async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Insert or replace user-scoped state for an application user.""" + await async_(self._upsert_user_state)(app_name, user_id, state) + + async def get_metadata(self, key: str) -> "str | None": + """Return a value from the ADK internal metadata table.""" + return await async_(self._get_metadata)(key) + + async def set_metadata(self, key: str, value: str) -> None: + """Set a value in the ADK internal metadata table.""" + await async_(self._set_metadata)(key, value) + + async def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session.""" + await async_(self._append_event)(event_record) + def _detect_dialect(self) -> str: """Detect ADBC driver dialect from connection config. @@ -177,11 +253,23 @@ def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None": return None return from_json(str(data)) # type: ignore[no-any-return] + def _decode_timestamp(self, value: Any) -> datetime: + """Convert ADBC timestamp values to timezone-aware UTC datetimes.""" + if isinstance(value, datetime): + decoded = value + elif isinstance(value, (int, float)): + decoded = datetime.fromtimestamp(float(value), tz=timezone.utc) + else: + decoded = datetime.fromisoformat(str(value)) + if decoded.tzinfo is None: + return decoded.replace(tzinfo=timezone.utc) + return decoded.astimezone(timezone.utc) + async def _get_create_sessions_table_sql(self) -> str: """Get CREATE TABLE SQL for sessions with dialect dispatch. Returns: - SQL statement to create adk_sessions table. + SQL statement to create adk_session table. """ if self._dialect == DIALECT_POSTGRESQL: return self._get_sessions_ddl_postgresql() @@ -287,7 +375,7 @@ async def _get_create_events_table_sql(self) -> str: """Get CREATE TABLE SQL for events with dialect dispatch. Returns: - SQL statement to create adk_events table. + SQL statement to create adk_event table. """ if self._dialect == DIALECT_POSTGRESQL: return self._get_events_ddl_postgresql() @@ -399,6 +487,91 @@ def _get_events_ddl_generic(self) -> str: ) """ + async def _get_create_app_states_table_sql(self) -> str: + """Get CREATE TABLE SQL for app-scoped state with dialect dispatch.""" + if self._dialect == DIALECT_POSTGRESQL: + state_type = "JSONB" + time_type = "TIMESTAMPTZ" + elif self._dialect == DIALECT_DUCKDB: + state_type = "JSON" + time_type = "TIMESTAMP" + elif self._dialect == DIALECT_SNOWFLAKE: + state_type = "VARIANT" + time_type = "TIMESTAMP_TZ" + elif self._dialect == DIALECT_SQLITE: + state_type = "TEXT" + time_type = "TIMESTAMP" + else: + state_type = "TEXT" + time_type = "TIMESTAMP" + + return f""" + CREATE TABLE IF NOT EXISTS {self._app_state_table} ( + app_name VARCHAR(128) PRIMARY KEY, + state {state_type} NOT NULL, + update_time {time_type} NOT NULL DEFAULT CURRENT_TIMESTAMP + ) + """ + + async def _get_create_user_states_table_sql(self) -> str: + """Get CREATE TABLE SQL for user-scoped state with dialect dispatch.""" + if self._dialect == DIALECT_POSTGRESQL: + state_type = "JSONB" + time_type = "TIMESTAMPTZ" + elif self._dialect == DIALECT_DUCKDB: + state_type = "JSON" + time_type = "TIMESTAMP" + elif self._dialect == DIALECT_SNOWFLAKE: + state_type = "VARIANT" + time_type = "TIMESTAMP_TZ" + elif self._dialect == DIALECT_SQLITE: + state_type = "TEXT" + time_type = "TIMESTAMP" + else: + state_type = "TEXT" + time_type = "TIMESTAMP" + + return f""" + CREATE TABLE IF NOT EXISTS {self._user_state_table} ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + state {state_type} NOT NULL, + update_time {time_type} NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (app_name, user_id) + ) + """ + + async def _get_create_metadata_table_sql(self) -> str: + """Get CREATE TABLE SQL for ADK internal metadata.""" + return f""" + CREATE TABLE IF NOT EXISTS {self._metadata_table} ( + key VARCHAR(128) PRIMARY KEY, + value VARCHAR(512) NOT NULL + ) + """ + + async def _get_seed_metadata_sql(self) -> str: + """Get SQL that seeds the ADK schema version metadata row.""" + return f""" + INSERT INTO {self._metadata_table} (key, value) + SELECT 'schema_version', '1' + WHERE NOT EXISTS ( + SELECT 1 FROM {self._metadata_table} WHERE key = 'schema_version' + ) + """ + + def _get_drop_app_states_table_sql(self) -> str: + """Get DROP TABLE SQL for app-scoped state.""" + return f"DROP TABLE IF EXISTS {self._app_state_table}" + + def _get_drop_user_states_table_sql(self) -> str: + """Get DROP TABLE SQL for user-scoped state.""" + return f"DROP TABLE IF EXISTS {self._user_state_table}" + + def _get_drop_metadata_table_sql(self) -> str: + """Get DROP TABLE SQL for ADK internal metadata.""" + return f"DROP TABLE IF EXISTS {self._metadata_table}" + def _get_drop_tables_sql(self) -> "list[str]": """Get DROP TABLE SQL statements. @@ -409,7 +582,13 @@ def _get_drop_tables_sql(self) -> "list[str]": Order matters: drop events table (child) before sessions (parent). Most databases automatically drop indexes when dropping tables. """ - return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE IF EXISTS {self._events_table}", + f"DROP TABLE IF EXISTS {self._session_table}", + ] def _create_tables(self) -> None: """Create both sessions and events tables if they don't exist.""" @@ -444,13 +623,21 @@ def _create_tables(self) -> None: ) cursor.execute(events_idx) conn.commit() + + cursor.execute(run_(self._get_create_app_states_table_sql)()) + conn.commit() + + cursor.execute(run_(self._get_create_user_states_table_sql)()) + conn.commit() + + cursor.execute(run_(self._get_create_metadata_table_sql)()) + conn.commit() + + cursor.execute(run_(self._get_seed_metadata_sql)()) + conn.commit() finally: cursor.close() - async def create_tables(self) -> None: - """Create tables if they don't exist.""" - await async_(self._create_tables)() - def _enable_foreign_keys(self, cursor: Any, conn: Any) -> None: """Enable foreign key constraints for SQLite. @@ -514,12 +701,6 @@ def _create_session( raise RuntimeError(msg) return result - async def create_session( - self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None - ) -> SessionRecord: - """Create a new session.""" - return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) - def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = None) -> "SessionRecord | None": """Get session by ID. @@ -559,8 +740,8 @@ def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = No app_name=row[1], user_id=row[2], state=self._deserialize_state(row[3]), - create_time=row[4], - update_time=row[5], + create_time=self._decode_timestamp(row[4]), + update_time=self._decode_timestamp(row[5]), ) finally: cursor.close() @@ -570,12 +751,6 @@ def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = No return None raise - async def get_session( - self, session_id: str, *, renew_for: "int | timedelta | None" = None - ) -> "SessionRecord | None": - """Get session by ID.""" - return await async_(self._get_session)(session_id, renew_for) - def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: """Update session state. @@ -602,10 +777,6 @@ def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> Non finally: cursor.close() - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state.""" - await async_(self._update_session_state)(session_id, state) - def _delete_session(self, session_id: str) -> None: """Delete session and all associated events (cascade). @@ -626,10 +797,6 @@ def _delete_session(self, session_id: str) -> None: finally: cursor.close() - async def delete_session(self, session_id: str) -> None: - """Delete session and associated events.""" - await async_(self._delete_session)(session_id) - def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": """List sessions for an app, optionally filtered by user. @@ -673,8 +840,8 @@ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Ses app_name=row[1], user_id=row[2], state=self._deserialize_state(row[3]), - create_time=row[4], - update_time=row[5], + create_time=self._decode_timestamp(row[4]), + update_time=self._decode_timestamp(row[5]), ) for row in rows ] @@ -686,10 +853,6 @@ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Ses return [] raise - async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app.""" - return await async_(self._list_sessions)(app_name, user_id) - def _insert_event(self, event_record: "EventRecord") -> None: """Insert an event record into the events table. @@ -788,16 +951,10 @@ def _append_event_and_update_state( app_name=row[1], user_id=row[2], state=self._deserialize_state(row[3]), - create_time=row[4], - update_time=row[5], + create_time=self._decode_timestamp(row[4]), + update_time=self._decode_timestamp(row[5]), ) - async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" - ) -> SessionRecord: - """Atomically append an event and update the session's durable state.""" - return await async_(self._append_event_and_update_state)(event_record, session_id, state) - def _get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None ) -> "list[EventRecord]": @@ -844,7 +1001,7 @@ def _get_events( session_id=row[0], invocation_id=row[1], author=row[2], - timestamp=row[3], + timestamp=self._decode_timestamp(row[3]), event_data=self._deserialize_json_field(row[4]) or {}, ) for row in rows @@ -857,12 +1014,6 @@ def _get_events( return [] raise - async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None - ) -> "list[EventRecord]": - """Get events for a session.""" - return await async_(self._get_events)(session_id, after_timestamp, limit) - def _delete_expired_events(self, before: "datetime") -> int: count_sql = f"SELECT COUNT(*) FROM {self._events_table} WHERE timestamp < ?" delete_sql = f"DELETE FROM {self._events_table} WHERE timestamp < ?" @@ -885,10 +1036,6 @@ def _delete_expired_events(self, before: "datetime") -> int: return 0 raise - async def delete_expired_events(self, before: "datetime") -> int: - """Delete events older than the given timestamp.""" - return await async_(self._delete_expired_events)(before) - def _delete_idle_sessions(self, updated_before: "datetime") -> int: count_sql = f"SELECT COUNT(*) FROM {self._session_table} WHERE update_time < ?" delete_events_sql = f""" @@ -916,18 +1063,117 @@ def _delete_idle_sessions(self, updated_before: "datetime") -> int: return 0 raise - async def delete_idle_sessions(self, updated_before: "datetime") -> int: - """Delete sessions whose update_time predates the given threshold.""" - return await async_(self._delete_idle_sessions)(updated_before) + def _get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Return app-scoped state for an application.""" + sql = f"SELECT state FROM {self._app_state_table} WHERE app_name = ?" + + try: + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, (app_name,)) + row = cursor.fetchone() + return self._deserialize_state(row[0]) if row is not None else None + finally: + cursor.close() + except Exception as e: + error_msg = str(e).lower() + if any(pattern in error_msg for pattern in ADBC_TABLE_NOT_FOUND_PATTERNS): + return None + raise + + def _get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Return user-scoped state for an application user.""" + sql = f"SELECT state FROM {self._user_state_table} WHERE app_name = ? AND user_id = ?" + + try: + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, (app_name, user_id)) + row = cursor.fetchone() + return self._deserialize_state(row[0]) if row is not None else None + finally: + cursor.close() + except Exception as e: + error_msg = str(e).lower() + if any(pattern in error_msg for pattern in ADBC_TABLE_NOT_FOUND_PATTERNS): + return None + raise + + def _upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Insert or replace app-scoped state for an application.""" + delete_sql = f"DELETE FROM {self._app_state_table} WHERE app_name = ?" + insert_sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES (?, ?, ?) + """ + + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(delete_sql, (app_name,)) + cursor.execute(insert_sql, (app_name, self._serialize_state(state), datetime.now(timezone.utc))) + conn.commit() + finally: + cursor.close() + + def _upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Insert or replace user-scoped state for an application user.""" + delete_sql = f"DELETE FROM {self._user_state_table} WHERE app_name = ? AND user_id = ?" + insert_sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (?, ?, ?, ?) + """ + + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(delete_sql, (app_name, user_id)) + cursor.execute( + insert_sql, (app_name, user_id, self._serialize_state(state), datetime.now(timezone.utc)) + ) + conn.commit() + finally: + cursor.close() + + def _get_metadata(self, key: str) -> "str | None": + """Return a value from the ADK internal metadata table.""" + sql = f"SELECT value FROM {self._metadata_table} WHERE key = ?" + + try: + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, (key,)) + row = cursor.fetchone() + return row[0] if row is not None else None + finally: + cursor.close() + except Exception as e: + error_msg = str(e).lower() + if any(pattern in error_msg for pattern in ADBC_TABLE_NOT_FOUND_PATTERNS): + return None + raise + + def _set_metadata(self, key: str, value: str) -> None: + """Set a value in the ADK internal metadata table.""" + delete_sql = f"DELETE FROM {self._metadata_table} WHERE key = ?" + insert_sql = f"INSERT INTO {self._metadata_table} (key, value) VALUES (?, ?)" + + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(delete_sql, (key,)) + cursor.execute(insert_sql, (key, value)) + conn.commit() + finally: + cursor.close() def _append_event(self, event_record: EventRecord) -> None: """Synchronous implementation of append_event.""" self._insert_event(event_record) - async def append_event(self, event_record: EventRecord) -> None: - """Append an event to a session.""" - await async_(self._append_event)(event_record) - class AdbcADKMemoryStore(BaseAsyncADKMemoryStore["AdbcConfig"]): """ADBC synchronous ADK memory store for Arrow Database Connectivity.""" @@ -942,6 +1188,28 @@ def __init__(self, config: "AdbcConfig") -> None: def dialect(self) -> str: return self._dialect + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + + async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication.""" + return await async_(self._insert_memory_entries)(entries, owner_id) + + async def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query.""" + return await async_(self._search_entries)(query, app_name, user_id, limit) + + async def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return await async_(self._delete_entries_by_session)(session_id) + + async def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + return await async_(self._delete_entries_older_than)(days) + def _detect_dialect(self) -> str: driver_name = self._config.connection_config.get("driver_name", "").lower() if "postgres" in driver_name: @@ -1103,10 +1371,6 @@ def _create_tables(self) -> None: finally: cursor.close() - async def create_tables(self) -> None: - """Create tables if they don't exist.""" - await async_(self._create_tables)() - def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: if not self._enabled: msg = "Memory store is disabled" @@ -1213,10 +1477,6 @@ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "objec return inserted_count - async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: - """Bulk insert memory entries with deduplication.""" - return await async_(self._insert_memory_entries)(entries, owner_id) - def _search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": @@ -1257,12 +1517,6 @@ def _search_entries( return self._rows_to_records(rows) - async def search_entries( - self, query: str, app_name: str, user_id: str, limit: "int | None" = None - ) -> "list[MemoryRecord]": - """Search memory entries by text query.""" - return await async_(self._search_entries)(query, app_name, user_id, limit) - def _delete_entries_by_session(self, session_id: str) -> int: use_returning = self._dialect in {DIALECT_SQLITE, DIALECT_POSTGRESQL, DIALECT_DUCKDB} if use_returning: @@ -1282,10 +1536,6 @@ def _delete_entries_by_session(self, session_id: str) -> int: finally: cursor.close() - async def delete_entries_by_session(self, session_id: str) -> int: - """Delete all memory entries for a specific session.""" - return await async_(self._delete_entries_by_session)(session_id) - def _delete_entries_older_than(self, days: int) -> int: cutoff = self._encode_timestamp(datetime.now(timezone.utc) - timedelta(days=days)) use_returning = self._dialect in {DIALECT_SQLITE, DIALECT_POSTGRESQL, DIALECT_DUCKDB} @@ -1306,10 +1556,6 @@ def _delete_entries_older_than(self, days: int) -> int: finally: cursor.close() - async def delete_entries_older_than(self, days: int) -> int: - """Delete memory entries older than specified days.""" - return await async_(self._delete_entries_older_than)(days) - def _rows_to_records(self, rows: "list[Any]") -> "list[MemoryRecord]": records: list[MemoryRecord] = [] for row in rows: diff --git a/sqlspec/adapters/aiomysql/adk/store.py b/sqlspec/adapters/aiomysql/adk/store.py index 0237013ab..8a81c5df2 100644 --- a/sqlspec/adapters/aiomysql/adk/store.py +++ b/sqlspec/adapters/aiomysql/adk/store.py @@ -22,6 +22,26 @@ MYSQL_TABLE_NOT_FOUND_ERROR: Final = 1146 +def _parse_owner_id_column_for_mysql(column_ddl: str) -> "tuple[str, str]": + """Parse owner ID column DDL for MySQL FOREIGN KEY syntax. + + Args: + column_ddl: Column DDL like "tenant_id BIGINT NOT NULL REFERENCES tenants(id) ON DELETE CASCADE". + + Returns: + Tuple of (column_definition, foreign_key_constraint). + """ + references_match = re.search(r"\s+REFERENCES\s+(.+)", column_ddl, re.IGNORECASE) + if not references_match: + return (column_ddl.strip(), "") + + col_def = column_ddl[: references_match.start()].strip() + fk_clause = references_match.group(1).strip() + col_name = col_def.split()[0] + fk_constraint = f"FOREIGN KEY ({col_name}) REFERENCES {fk_clause}" + return (col_def, fk_constraint) + + class AiomysqlADKStore(BaseAsyncADKStore["AiomysqlConfig"]): """MySQL/MariaDB ADK store using aiomysql driver. @@ -51,97 +71,15 @@ def __init__(self, config: "AiomysqlConfig") -> None: """ super().__init__(config) - def _parse_owner_id_column_for_mysql(self, column_ddl: str) -> "tuple[str, str]": - """Parse owner ID column DDL for MySQL FOREIGN KEY syntax. - - MySQL ignores inline REFERENCES syntax in column definitions. - This method extracts the column definition and creates a separate - FOREIGN KEY constraint. - - Args: - column_ddl: Column DDL like "tenant_id BIGINT NOT NULL REFERENCES tenants(id) ON DELETE CASCADE" - - Returns: - Tuple of (column_definition, foreign_key_constraint) - """ - references_match = re.search(r"\s+REFERENCES\s+(.+)", column_ddl, re.IGNORECASE) - - if not references_match: - return (column_ddl.strip(), "") - - col_def = column_ddl[: references_match.start()].strip() - fk_clause = references_match.group(1).strip() - col_name = col_def.split()[0] - fk_constraint = f"FOREIGN KEY ({col_name}) REFERENCES {fk_clause}" - - return (col_def, fk_constraint) - - async def _get_create_sessions_table_sql(self) -> str: - """Get MySQL CREATE TABLE SQL for sessions. - - Returns: - SQL statement to create adk_sessions table with indexes. - """ - owner_id_col = "" - fk_constraint = "" - - if self._owner_id_column_ddl: - col_def, fk_def = self._parse_owner_id_column_for_mysql(self._owner_id_column_ddl) - owner_id_col = f"{col_def}," - if fk_def: - fk_constraint = f",\n {fk_def}" - - return f""" - CREATE TABLE IF NOT EXISTS {self._session_table} ( - id VARCHAR(128) PRIMARY KEY, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - {owner_id_col} - state JSON NOT NULL, - create_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6), - INDEX idx_{self._session_table}_app_user (app_name, user_id), - INDEX idx_{self._session_table}_update_time (update_time DESC){fk_constraint} - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - """ - - async def _get_create_events_table_sql(self) -> str: - """Get MySQL CREATE TABLE SQL for events. - - Returns: - SQL statement to create adk_events table with indexes. - - Notes: - Post clean-break schema: 5 columns only. - - session_id, invocation_id, author: indexed scalars - - timestamp: microsecond-precision TIMESTAMP - - event_data: full Event as native JSON - """ - return f""" - CREATE TABLE IF NOT EXISTS {self._events_table} ( - session_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(128) NOT NULL, - timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - event_data JSON NOT NULL, - FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE, - INDEX idx_{self._events_table}_session (session_id, timestamp ASC) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - """ - - def _get_drop_tables_sql(self) -> "list[str]": - """Get MySQL DROP TABLE SQL statements. - - Returns: - List of SQL statements to drop tables and indexes. - """ - return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] - async def create_tables(self) -> None: """Create both sessions and events tables if they don't exist.""" async with self._config.provide_session() as driver: await driver.execute_script(await self._get_create_sessions_table_sql()) await driver.execute_script(await self._get_create_events_table_sql()) + await driver.execute_script(await self._get_create_app_states_table_sql()) + await driver.execute_script(await self._get_create_user_states_table_sql()) + await driver.execute_script(await self._get_create_metadata_table_sql()) + await driver.execute_script(await self._get_seed_metadata_sql()) async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None @@ -506,71 +444,253 @@ async def delete_idle_sessions(self, updated_before: "datetime") -> int: return 0 raise + async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Return app-scoped state for an application.""" + sql = f"SELECT state FROM {self._app_state_table} WHERE app_name = %s" -def _parse_owner_id_column_for_mysql(column_ddl: str) -> "tuple[str, str]": - """Parse owner ID column DDL for MySQL FOREIGN KEY syntax. + try: + async with ( + self._config.provide_connection() as conn, + AiomysqlCursor(conn, cursor_class=AiomysqlRawCursor) as cursor, + ): + await cursor.execute(sql, (app_name,)) + row = await cursor.fetchone() + return from_json(row[0]) if row is not None and isinstance(row[0], str) else (row[0] if row else None) + except pymysql.err.ProgrammingError as e: + if "doesn't exist" in str(e) or e.args[0] == MYSQL_TABLE_NOT_FOUND_ERROR: + return None + raise - Args: - column_ddl: Column DDL like "tenant_id BIGINT NOT NULL REFERENCES tenants(id) ON DELETE CASCADE". + async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Return user-scoped state for an application user.""" + sql = f"SELECT state FROM {self._user_state_table} WHERE app_name = %s AND user_id = %s" - Returns: - Tuple of (column_definition, foreign_key_constraint). - """ - references_match = re.search(r"\s+REFERENCES\s+(.+)", column_ddl, re.IGNORECASE) - if not references_match: - return (column_ddl.strip(), "") + try: + async with ( + self._config.provide_connection() as conn, + AiomysqlCursor(conn, cursor_class=AiomysqlRawCursor) as cursor, + ): + await cursor.execute(sql, (app_name, user_id)) + row = await cursor.fetchone() + return from_json(row[0]) if row is not None and isinstance(row[0], str) else (row[0] if row else None) + except pymysql.err.ProgrammingError as e: + if "doesn't exist" in str(e) or e.args[0] == MYSQL_TABLE_NOT_FOUND_ERROR: + return None + raise - col_def = column_ddl[: references_match.start()].strip() - fk_clause = references_match.group(1).strip() - col_name = col_def.split()[0] - fk_constraint = f"FOREIGN KEY ({col_name}) REFERENCES {fk_clause}" - return (col_def, fk_constraint) + async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Insert or replace app-scoped state for an application.""" + sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES (%s, %s, UTC_TIMESTAMP(6)) + ON DUPLICATE KEY UPDATE state = VALUES(state), update_time = UTC_TIMESTAMP(6) + """ + async with ( + self._config.provide_connection() as conn, + AiomysqlCursor(conn, cursor_class=AiomysqlRawCursor) as cursor, + ): + await cursor.execute(sql, (app_name, to_json(state))) + await conn.commit() -class AiomysqlADKMemoryStore(BaseAsyncADKMemoryStore["AiomysqlConfig"]): - """MySQL/MariaDB ADK memory store using aiomysql driver.""" + async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Insert or replace user-scoped state for an application user.""" + sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (%s, %s, %s, UTC_TIMESTAMP(6)) + ON DUPLICATE KEY UPDATE state = VALUES(state), update_time = UTC_TIMESTAMP(6) + """ - __slots__ = () + async with ( + self._config.provide_connection() as conn, + AiomysqlCursor(conn, cursor_class=AiomysqlRawCursor) as cursor, + ): + await cursor.execute(sql, (app_name, user_id, to_json(state))) + await conn.commit() - def __init__(self, config: "AiomysqlConfig") -> None: - """Initialize aiomysql memory store.""" - super().__init__(config) + async def get_metadata(self, key: str) -> "str | None": + """Return a value from the ADK internal metadata table.""" + sql = f"SELECT value FROM {self._metadata_table} WHERE `key` = %s" - async def _get_create_memory_table_sql(self) -> str: - """Get MySQL CREATE TABLE SQL for memory entries.""" - owner_id_line = "" + try: + async with ( + self._config.provide_connection() as conn, + AiomysqlCursor(conn, cursor_class=AiomysqlRawCursor) as cursor, + ): + await cursor.execute(sql, (key,)) + row = await cursor.fetchone() + return row[0] if row is not None else None + except pymysql.err.ProgrammingError as e: + if "doesn't exist" in str(e) or e.args[0] == MYSQL_TABLE_NOT_FOUND_ERROR: + return None + raise + + async def set_metadata(self, key: str, value: str) -> None: + """Set a value in the ADK internal metadata table.""" + sql = f""" + INSERT INTO {self._metadata_table} (`key`, value) + VALUES (%s, %s) + ON DUPLICATE KEY UPDATE value = VALUES(value) + """ + + async with ( + self._config.provide_connection() as conn, + AiomysqlCursor(conn, cursor_class=AiomysqlRawCursor) as cursor, + ): + await cursor.execute(sql, (key, value)) + await conn.commit() + + def _parse_owner_id_column_for_mysql(self, column_ddl: str) -> "tuple[str, str]": + """Parse owner ID column DDL for MySQL FOREIGN KEY syntax. + + MySQL ignores inline REFERENCES syntax in column definitions. + This method extracts the column definition and creates a separate + FOREIGN KEY constraint. + + Args: + column_ddl: Column DDL like "tenant_id BIGINT NOT NULL REFERENCES tenants(id) ON DELETE CASCADE" + + Returns: + Tuple of (column_definition, foreign_key_constraint) + """ + references_match = re.search(r"\s+REFERENCES\s+(.+)", column_ddl, re.IGNORECASE) + + if not references_match: + return (column_ddl.strip(), "") + + col_def = column_ddl[: references_match.start()].strip() + fk_clause = references_match.group(1).strip() + col_name = col_def.split()[0] + fk_constraint = f"FOREIGN KEY ({col_name}) REFERENCES {fk_clause}" + + return (col_def, fk_constraint) + + async def _get_create_sessions_table_sql(self) -> str: + """Get MySQL CREATE TABLE SQL for sessions. + + Returns: + SQL statement to create adk_session table with indexes. + """ + owner_id_col = "" fk_constraint = "" + if self._owner_id_column_ddl: - col_def, fk_def = _parse_owner_id_column_for_mysql(self._owner_id_column_ddl) - owner_id_line = f",\n {col_def}" + col_def, fk_def = self._parse_owner_id_column_for_mysql(self._owner_id_column_ddl) + owner_id_col = f"{col_def}," if fk_def: fk_constraint = f",\n {fk_def}" - fts_index = "" - if self._use_fts: - fts_index = f",\n FULLTEXT INDEX idx_{self._memory_table}_fts (content_text)" - return f""" - CREATE TABLE IF NOT EXISTS {self._memory_table} ( + CREATE TABLE IF NOT EXISTS {self._session_table} ( id VARCHAR(128) PRIMARY KEY, - session_id VARCHAR(128) NOT NULL, app_name VARCHAR(128) NOT NULL, user_id VARCHAR(128) NOT NULL, - event_id VARCHAR(128) NOT NULL UNIQUE, - author VARCHAR(256){owner_id_line}, + {owner_id_col} + state JSON NOT NULL, + create_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6), + INDEX idx_{self._session_table}_app_user (app_name, user_id), + INDEX idx_{self._session_table}_update_time (update_time DESC){fk_constraint} + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + async def _get_create_events_table_sql(self) -> str: + """Get MySQL CREATE TABLE SQL for events. + + Returns: + SQL statement to create adk_event table with indexes. + + Notes: + Post clean-break schema: 5 columns only. + - session_id, invocation_id, author: indexed scalars + - timestamp: microsecond-precision TIMESTAMP + - event_data: full Event as native JSON + """ + return f""" + CREATE TABLE IF NOT EXISTS {self._events_table} ( + session_id VARCHAR(128) NOT NULL, + invocation_id VARCHAR(256) NOT NULL, + author VARCHAR(128) NOT NULL, timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - content_json JSON NOT NULL, - content_text TEXT NOT NULL, - metadata_json JSON, - inserted_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - INDEX idx_{self._memory_table}_app_user_time (app_name, user_id, timestamp), - INDEX idx_{self._memory_table}_session (session_id){fts_index}{fk_constraint} + event_data JSON NOT NULL, + FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE, + INDEX idx_{self._events_table}_session (session_id, timestamp ASC) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci """ - def _get_drop_memory_table_sql(self) -> "list[str]": - """Get MySQL DROP TABLE SQL statements.""" - return [f"DROP TABLE IF EXISTS {self._memory_table}"] + async def _get_create_app_states_table_sql(self) -> str: + """Get MySQL CREATE TABLE SQL for app-scoped state.""" + return f""" + CREATE TABLE IF NOT EXISTS {self._app_state_table} ( + app_name VARCHAR(128) PRIMARY KEY, + state JSON NOT NULL, + update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + async def _get_create_user_states_table_sql(self) -> str: + """Get MySQL CREATE TABLE SQL for user-scoped state.""" + return f""" + CREATE TABLE IF NOT EXISTS {self._user_state_table} ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + state JSON NOT NULL, + update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6), + PRIMARY KEY (app_name, user_id) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + async def _get_create_metadata_table_sql(self) -> str: + """Get MySQL CREATE TABLE SQL for ADK internal metadata.""" + return f""" + CREATE TABLE IF NOT EXISTS {self._metadata_table} ( + `key` VARCHAR(128) PRIMARY KEY, + value VARCHAR(512) NOT NULL + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + async def _get_seed_metadata_sql(self) -> str: + """Get MySQL SQL that seeds the ADK schema version metadata row.""" + return f""" + INSERT IGNORE INTO {self._metadata_table} (`key`, value) + VALUES ('schema_version', '1') + """ + + def _get_drop_app_states_table_sql(self) -> str: + """Get MySQL DROP TABLE SQL for app-scoped state.""" + return f"DROP TABLE IF EXISTS {self._app_state_table}" + + def _get_drop_user_states_table_sql(self) -> str: + """Get MySQL DROP TABLE SQL for user-scoped state.""" + return f"DROP TABLE IF EXISTS {self._user_state_table}" + + def _get_drop_metadata_table_sql(self) -> str: + """Get MySQL DROP TABLE SQL for ADK internal metadata.""" + return f"DROP TABLE IF EXISTS {self._metadata_table}" + + def _get_drop_tables_sql(self) -> "list[str]": + """Get MySQL DROP TABLE SQL statements. + + Returns: + List of SQL statements to drop tables and indexes. + """ + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE IF EXISTS {self._events_table}", + f"DROP TABLE IF EXISTS {self._session_table}", + ] + + +class AiomysqlADKMemoryStore(BaseAsyncADKMemoryStore["AiomysqlConfig"]): + """MySQL/MariaDB ADK memory store using aiomysql driver.""" + + __slots__ = () + + def __init__(self, config: "AiomysqlConfig") -> None: + """Initialize aiomysql memory store.""" + super().__init__(config) async def create_tables(self) -> None: """Create the memory table and indexes if they don't exist.""" @@ -716,3 +836,39 @@ async def delete_entries_older_than(self, days: int) -> int: await cursor.execute(sql, (days,)) await conn.commit() return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + + async def _get_create_memory_table_sql(self) -> str: + """Get MySQL CREATE TABLE SQL for memory entries.""" + owner_id_line = "" + fk_constraint = "" + if self._owner_id_column_ddl: + col_def, fk_def = _parse_owner_id_column_for_mysql(self._owner_id_column_ddl) + owner_id_line = f",\n {col_def}" + if fk_def: + fk_constraint = f",\n {fk_def}" + + fts_index = "" + if self._use_fts: + fts_index = f",\n FULLTEXT INDEX idx_{self._memory_table}_fts (content_text)" + + return f""" + CREATE TABLE IF NOT EXISTS {self._memory_table} ( + id VARCHAR(128) PRIMARY KEY, + session_id VARCHAR(128) NOT NULL, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + event_id VARCHAR(128) NOT NULL UNIQUE, + author VARCHAR(256){owner_id_line}, + timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + content_json JSON NOT NULL, + content_text TEXT NOT NULL, + metadata_json JSON, + inserted_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + INDEX idx_{self._memory_table}_app_user_time (app_name, user_id, timestamp), + INDEX idx_{self._memory_table}_session (session_id){fts_index}{fk_constraint} + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + def _get_drop_memory_table_sql(self) -> "list[str]": + """Get MySQL DROP TABLE SQL statements.""" + return [f"DROP TABLE IF EXISTS {self._memory_table}"] diff --git a/sqlspec/adapters/aiosqlite/adk/store.py b/sqlspec/adapters/aiosqlite/adk/store.py index e95bce977..de182d9c1 100644 --- a/sqlspec/adapters/aiosqlite/adk/store.py +++ b/sqlspec/adapters/aiosqlite/adk/store.py @@ -105,104 +105,21 @@ def __init__(self, config: "AiosqliteConfig") -> None: Notes: Configuration is read from config.extension_config["adk"]: - - session_table: Sessions table name (default: "adk_sessions") - - events_table: Events table name (default: "adk_events") + - session_table: Sessions table name (default: "adk_session") + - events_table: Events table name (default: "adk_event") """ super().__init__(config) - async def _apply_pragmas(self, connection: Any) -> None: - """Apply PRAGMA optimization profile for this connection. - - Args: - connection: Aiosqlite connection. - - Notes: - Enables foreign keys and applies performance PRAGMAs. - For file-based databases, adds cache_size, mmap_size, - and journal_size_limit optimizations. - """ - await connection.execute("PRAGMA foreign_keys = ON") - await connection.execute("PRAGMA cache_size = -64000") - await connection.execute("PRAGMA mmap_size = 30000000") - await connection.execute("PRAGMA journal_size_limit = 67108864") - - async def _get_create_sessions_table_sql(self) -> str: - """Get SQLite CREATE TABLE SQL for sessions. - - Returns: - SQL statement to create adk_sessions table with indexes. - - Notes: - - TEXT for IDs, names, and JSON state - - REAL for Julian Day timestamps - - Optional owner ID column for multi-tenant scenarios - - Composite index on (app_name, user_id) - - Index on update_time DESC for recent session queries - """ - owner_id_line = "" - if self._owner_id_column_ddl: - owner_id_line = f",\n {self._owner_id_column_ddl}" - - return f""" - CREATE TABLE IF NOT EXISTS {self._session_table} ( - id TEXT PRIMARY KEY, - app_name TEXT NOT NULL, - user_id TEXT NOT NULL{owner_id_line}, - state TEXT NOT NULL DEFAULT '{{}}', - create_time REAL NOT NULL, - update_time REAL NOT NULL - ); - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_app_user - ON {self._session_table}(app_name, user_id); - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time - ON {self._session_table}(update_time DESC); - """ - - async def _get_create_events_table_sql(self) -> str: - """Get SQLite CREATE TABLE SQL for events. - - Returns: - SQL statement to create adk_events table with indexes. - - Notes: - - TEXT for IDs and indexed scalars - - TEXT for full event JSON (event_data) - - REAL for Julian Day timestamps - - Foreign key to sessions with CASCADE delete - - Index on (session_id, timestamp ASC) - """ - return f""" - CREATE TABLE IF NOT EXISTS {self._events_table} ( - id TEXT PRIMARY KEY, - session_id TEXT NOT NULL, - invocation_id TEXT, - author TEXT, - timestamp REAL NOT NULL, - event_data TEXT NOT NULL, - FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE - ); - CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session - ON {self._events_table}(session_id, timestamp ASC); - """ - - def _get_drop_tables_sql(self) -> "list[str]": - """Get SQLite DROP TABLE SQL statements. - - Returns: - List of SQL statements to drop tables and indexes. - - Notes: - Order matters: drop events table (child) before sessions (parent). - SQLite automatically drops indexes when dropping tables. - """ - return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] - async def create_tables(self) -> None: """Create both sessions and events tables if they don't exist.""" async with self._config.provide_session() as driver: await self._apply_pragmas(driver.connection) await driver.execute_script(await self._get_create_sessions_table_sql()) await driver.execute_script(await self._get_create_events_table_sql()) + await driver.execute_script(await self._get_create_app_states_table_sql()) + await driver.execute_script(await self._get_create_user_states_table_sql()) + await driver.execute_script(await self._get_create_metadata_table_sql()) + await driver.execute_script(await self._get_seed_metadata_sql()) async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None @@ -573,6 +490,239 @@ async def delete_idle_sessions(self, updated_before: datetime) -> int: await conn.commit() return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Return app-scoped state for an application.""" + sql = f"SELECT state FROM {self._app_state_table} WHERE app_name = ?" + + try: + async with self._config.provide_connection() as conn: + await self._apply_pragmas(conn) + cursor = await conn.execute(sql, (app_name,)) + row = await cursor.fetchone() + return from_json(row[0]) if row is not None and row[0] else None + except sqlite3.OperationalError as exc: + if SQLITE_TABLE_NOT_FOUND_ERROR in str(exc): + return None + raise + + async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Return user-scoped state for an application user.""" + sql = f"SELECT state FROM {self._user_state_table} WHERE app_name = ? AND user_id = ?" + + try: + async with self._config.provide_connection() as conn: + await self._apply_pragmas(conn) + cursor = await conn.execute(sql, (app_name, user_id)) + row = await cursor.fetchone() + return from_json(row[0]) if row is not None and row[0] else None + except sqlite3.OperationalError as exc: + if SQLITE_TABLE_NOT_FOUND_ERROR in str(exc): + return None + raise + + async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Insert or replace app-scoped state for an application.""" + sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES (?, ?, ?) + ON CONFLICT(app_name) DO UPDATE SET + state = excluded.state, + update_time = excluded.update_time + """ + + async with self._config.provide_connection() as conn: + await self._apply_pragmas(conn) + await conn.execute(sql, (app_name, to_json(state), _datetime_to_julian(datetime.now(timezone.utc)))) + await conn.commit() + + async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Insert or replace user-scoped state for an application user.""" + sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (?, ?, ?, ?) + ON CONFLICT(app_name, user_id) DO UPDATE SET + state = excluded.state, + update_time = excluded.update_time + """ + + async with self._config.provide_connection() as conn: + await self._apply_pragmas(conn) + await conn.execute( + sql, (app_name, user_id, to_json(state), _datetime_to_julian(datetime.now(timezone.utc))) + ) + await conn.commit() + + async def get_metadata(self, key: str) -> "str | None": + """Return a value from the ADK internal metadata table.""" + sql = f"SELECT value FROM {self._metadata_table} WHERE key = ?" + + try: + async with self._config.provide_connection() as conn: + await self._apply_pragmas(conn) + cursor = await conn.execute(sql, (key,)) + row = await cursor.fetchone() + return row[0] if row is not None else None + except sqlite3.OperationalError as exc: + if SQLITE_TABLE_NOT_FOUND_ERROR in str(exc): + return None + raise + + async def set_metadata(self, key: str, value: str) -> None: + """Set a value in the ADK internal metadata table.""" + sql = f""" + INSERT INTO {self._metadata_table} (key, value) + VALUES (?, ?) + ON CONFLICT(key) DO UPDATE SET value = excluded.value + """ + + async with self._config.provide_connection() as conn: + await self._apply_pragmas(conn) + await conn.execute(sql, (key, value)) + await conn.commit() + + async def _apply_pragmas(self, connection: Any) -> None: + """Apply PRAGMA optimization profile for this connection. + + Args: + connection: Aiosqlite connection. + + Notes: + Enables foreign keys and applies performance PRAGMAs. + For file-based databases, adds cache_size, mmap_size, + and journal_size_limit optimizations. + """ + await connection.execute("PRAGMA foreign_keys = ON") + await connection.execute("PRAGMA cache_size = -64000") + await connection.execute("PRAGMA mmap_size = 30000000") + await connection.execute("PRAGMA journal_size_limit = 67108864") + + async def _get_create_sessions_table_sql(self) -> str: + """Get SQLite CREATE TABLE SQL for sessions. + + Returns: + SQL statement to create adk_session table with indexes. + + Notes: + - TEXT for IDs, names, and JSON state + - REAL for Julian Day timestamps + - Optional owner ID column for multi-tenant scenarios + - Composite index on (app_name, user_id) + - Index on update_time DESC for recent session queries + """ + owner_id_line = "" + if self._owner_id_column_ddl: + owner_id_line = f",\n {self._owner_id_column_ddl}" + + return f""" + CREATE TABLE IF NOT EXISTS {self._session_table} ( + id TEXT PRIMARY KEY, + app_name TEXT NOT NULL, + user_id TEXT NOT NULL{owner_id_line}, + state TEXT NOT NULL DEFAULT '{{}}', + create_time REAL NOT NULL, + update_time REAL NOT NULL + ); + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_app_user + ON {self._session_table}(app_name, user_id); + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time + ON {self._session_table}(update_time DESC); + """ + + async def _get_create_events_table_sql(self) -> str: + """Get SQLite CREATE TABLE SQL for events. + + Returns: + SQL statement to create adk_event table with indexes. + + Notes: + - TEXT for IDs and indexed scalars + - TEXT for full event JSON (event_data) + - REAL for Julian Day timestamps + - Foreign key to sessions with CASCADE delete + - Index on (session_id, timestamp ASC) + """ + return f""" + CREATE TABLE IF NOT EXISTS {self._events_table} ( + id TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + invocation_id TEXT, + author TEXT, + timestamp REAL NOT NULL, + event_data TEXT NOT NULL, + FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE + ); + CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session + ON {self._events_table}(session_id, timestamp ASC); + """ + + async def _get_create_app_states_table_sql(self) -> str: + """Get SQLite CREATE TABLE SQL for app-scoped state.""" + return f""" + CREATE TABLE IF NOT EXISTS {self._app_state_table} ( + app_name TEXT PRIMARY KEY, + state TEXT NOT NULL DEFAULT '{{}}', + update_time REAL NOT NULL + ); + """ + + async def _get_create_user_states_table_sql(self) -> str: + """Get SQLite CREATE TABLE SQL for user-scoped state.""" + return f""" + CREATE TABLE IF NOT EXISTS {self._user_state_table} ( + app_name TEXT NOT NULL, + user_id TEXT NOT NULL, + state TEXT NOT NULL DEFAULT '{{}}', + update_time REAL NOT NULL, + PRIMARY KEY (app_name, user_id) + ); + """ + + async def _get_create_metadata_table_sql(self) -> str: + """Get SQLite CREATE TABLE SQL for ADK internal metadata.""" + return f""" + CREATE TABLE IF NOT EXISTS {self._metadata_table} ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL + ); + """ + + async def _get_seed_metadata_sql(self) -> str: + """Get SQLite SQL that seeds the ADK schema version metadata row.""" + return f""" + INSERT OR IGNORE INTO {self._metadata_table} (key, value) + VALUES ('schema_version', '1') + """ + + def _get_drop_app_states_table_sql(self) -> str: + """Get SQLite DROP TABLE SQL for app-scoped state.""" + return f"DROP TABLE IF EXISTS {self._app_state_table}" + + def _get_drop_user_states_table_sql(self) -> str: + """Get SQLite DROP TABLE SQL for user-scoped state.""" + return f"DROP TABLE IF EXISTS {self._user_state_table}" + + def _get_drop_metadata_table_sql(self) -> str: + """Get SQLite DROP TABLE SQL for ADK internal metadata.""" + return f"DROP TABLE IF EXISTS {self._metadata_table}" + + def _get_drop_tables_sql(self) -> "list[str]": + """Get SQLite DROP TABLE SQL statements. + + Returns: + List of SQL statements to drop tables and indexes. + + Notes: + Order matters: drop events table (child) before sessions (parent). + SQLite automatically drops indexes when dropping tables. + """ + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE IF EXISTS {self._events_table}", + f"DROP TABLE IF EXISTS {self._session_table}", + ] + class AiosqliteADKMemoryStore(BaseAsyncADKMemoryStore["AiosqliteConfig"]): """Aiosqlite ADK memory store using asynchronous SQLite driver. @@ -633,84 +783,6 @@ def __init__(self, config: "AiosqliteConfig") -> None: """ super().__init__(config) - async def _get_create_memory_table_sql(self) -> str: - """Get SQLite CREATE TABLE SQL for memory entries. - - Returns: - SQL statement to create memory table with indexes. - - Notes: - - TEXT for IDs, names, and JSON content - - REAL for Julian Day timestamps - - UNIQUE constraint on event_id for deduplication - - Composite index on (app_name, user_id, timestamp DESC) - - Optional owner ID column for multi-tenancy - - Optional FTS5 virtual table for full-text search - """ - owner_id_line = "" - if self._owner_id_column_ddl: - owner_id_line = f",\n {self._owner_id_column_ddl}" - - fts_table = "" - if self._use_fts: - fts_table = f""" - CREATE VIRTUAL TABLE IF NOT EXISTS {self._memory_table}_fts USING fts5( - content_text, - content={self._memory_table}, - content_rowid=rowid - ); - - CREATE TRIGGER IF NOT EXISTS {self._memory_table}_ai AFTER INSERT ON {self._memory_table} BEGIN - INSERT INTO {self._memory_table}_fts(rowid, content_text) VALUES (new.rowid, new.content_text); - END; - - CREATE TRIGGER IF NOT EXISTS {self._memory_table}_ad AFTER DELETE ON {self._memory_table} BEGIN - INSERT INTO {self._memory_table}_fts({self._memory_table}_fts, rowid, content_text) - VALUES('delete', old.rowid, old.content_text); - END; - - CREATE TRIGGER IF NOT EXISTS {self._memory_table}_au AFTER UPDATE ON {self._memory_table} BEGIN - INSERT INTO {self._memory_table}_fts({self._memory_table}_fts, rowid, content_text) - VALUES('delete', old.rowid, old.content_text); - INSERT INTO {self._memory_table}_fts(rowid, content_text) VALUES (new.rowid, new.content_text); - END; - """ - - return f""" - CREATE TABLE IF NOT EXISTS {self._memory_table} ( - id TEXT PRIMARY KEY, - session_id TEXT NOT NULL, - app_name TEXT NOT NULL, - user_id TEXT NOT NULL, - event_id TEXT NOT NULL UNIQUE, - author TEXT{owner_id_line}, - timestamp REAL NOT NULL, - content_json TEXT NOT NULL, - content_text TEXT NOT NULL, - metadata_json TEXT, - inserted_at REAL NOT NULL - ); - - CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_app_user_time - ON {self._memory_table}(app_name, user_id, timestamp DESC); - - CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_session - ON {self._memory_table}(session_id); - {fts_table} - """ - - def _get_drop_memory_table_sql(self) -> "list[str]": - """Get SQLite DROP TABLE SQL statements.""" - statements = [f"DROP TABLE IF EXISTS {self._memory_table}"] - if self._use_fts: - statements.extend([ - f"DROP TABLE IF EXISTS {self._memory_table}_fts", - f"DROP TRIGGER IF EXISTS {self._memory_table}_ai", - f"DROP TRIGGER IF EXISTS {self._memory_table}_ad", - f"DROP TRIGGER IF EXISTS {self._memory_table}_au", - ]) - return statements - async def create_tables(self) -> None: """Create the memory table and indexes if they don't exist. @@ -855,3 +927,81 @@ async def delete_entries_older_than(self, days: int) -> int: cursor = await conn.execute(sql, (cutoff,)) await conn.commit() return cursor.rowcount + + async def _get_create_memory_table_sql(self) -> str: + """Get SQLite CREATE TABLE SQL for memory entries. + + Returns: + SQL statement to create memory table with indexes. + + Notes: + - TEXT for IDs, names, and JSON content + - REAL for Julian Day timestamps + - UNIQUE constraint on event_id for deduplication + - Composite index on (app_name, user_id, timestamp DESC) + - Optional owner ID column for multi-tenancy + - Optional FTS5 virtual table for full-text search + """ + owner_id_line = "" + if self._owner_id_column_ddl: + owner_id_line = f",\n {self._owner_id_column_ddl}" + + fts_table = "" + if self._use_fts: + fts_table = f""" + CREATE VIRTUAL TABLE IF NOT EXISTS {self._memory_table}_fts USING fts5( + content_text, + content={self._memory_table}, + content_rowid=rowid + ); + + CREATE TRIGGER IF NOT EXISTS {self._memory_table}_ai AFTER INSERT ON {self._memory_table} BEGIN + INSERT INTO {self._memory_table}_fts(rowid, content_text) VALUES (new.rowid, new.content_text); + END; + + CREATE TRIGGER IF NOT EXISTS {self._memory_table}_ad AFTER DELETE ON {self._memory_table} BEGIN + INSERT INTO {self._memory_table}_fts({self._memory_table}_fts, rowid, content_text) + VALUES('delete', old.rowid, old.content_text); + END; + + CREATE TRIGGER IF NOT EXISTS {self._memory_table}_au AFTER UPDATE ON {self._memory_table} BEGIN + INSERT INTO {self._memory_table}_fts({self._memory_table}_fts, rowid, content_text) + VALUES('delete', old.rowid, old.content_text); + INSERT INTO {self._memory_table}_fts(rowid, content_text) VALUES (new.rowid, new.content_text); + END; + """ + + return f""" + CREATE TABLE IF NOT EXISTS {self._memory_table} ( + id TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + app_name TEXT NOT NULL, + user_id TEXT NOT NULL, + event_id TEXT NOT NULL UNIQUE, + author TEXT{owner_id_line}, + timestamp REAL NOT NULL, + content_json TEXT NOT NULL, + content_text TEXT NOT NULL, + metadata_json TEXT, + inserted_at REAL NOT NULL + ); + + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_app_user_time + ON {self._memory_table}(app_name, user_id, timestamp DESC); + + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_session + ON {self._memory_table}(session_id); + {fts_table} + """ + + def _get_drop_memory_table_sql(self) -> "list[str]": + """Get SQLite DROP TABLE SQL statements.""" + statements = [f"DROP TABLE IF EXISTS {self._memory_table}"] + if self._use_fts: + statements.extend([ + f"DROP TABLE IF EXISTS {self._memory_table}_fts", + f"DROP TRIGGER IF EXISTS {self._memory_table}_ai", + f"DROP TRIGGER IF EXISTS {self._memory_table}_ad", + f"DROP TRIGGER IF EXISTS {self._memory_table}_au", + ]) + return statements diff --git a/sqlspec/adapters/asyncmy/adk/store.py b/sqlspec/adapters/asyncmy/adk/store.py index b843365ad..69709ae63 100644 --- a/sqlspec/adapters/asyncmy/adk/store.py +++ b/sqlspec/adapters/asyncmy/adk/store.py @@ -21,6 +21,26 @@ MYSQL_TABLE_NOT_FOUND_ERROR: Final = 1146 +def _parse_owner_id_column_for_mysql(column_ddl: str) -> "tuple[str, str]": + """Parse owner ID column DDL for MySQL FOREIGN KEY syntax. + + Args: + column_ddl: Column DDL like "tenant_id BIGINT NOT NULL REFERENCES tenants(id) ON DELETE CASCADE". + + Returns: + Tuple of (column_definition, foreign_key_constraint). + """ + references_match = re.search(r"\s+REFERENCES\s+(.+)", column_ddl, re.IGNORECASE) + if not references_match: + return (column_ddl.strip(), "") + + col_def = column_ddl[: references_match.start()].strip() + fk_clause = references_match.group(1).strip() + col_name = col_def.split()[0] + fk_constraint = f"FOREIGN KEY ({col_name}) REFERENCES {fk_clause}" + return (col_def, fk_constraint) + + class AsyncmyADKStore(BaseAsyncADKStore["AsyncmyConfig"]): """MySQL/MariaDB ADK store using AsyncMy driver. @@ -50,97 +70,15 @@ def __init__(self, config: "AsyncmyConfig") -> None: """ super().__init__(config) - def _parse_owner_id_column_for_mysql(self, column_ddl: str) -> "tuple[str, str]": - """Parse owner ID column DDL for MySQL FOREIGN KEY syntax. - - MySQL ignores inline REFERENCES syntax in column definitions. - This method extracts the column definition and creates a separate - FOREIGN KEY constraint. - - Args: - column_ddl: Column DDL like "tenant_id BIGINT NOT NULL REFERENCES tenants(id) ON DELETE CASCADE" - - Returns: - Tuple of (column_definition, foreign_key_constraint) - """ - references_match = re.search(r"\s+REFERENCES\s+(.+)", column_ddl, re.IGNORECASE) - - if not references_match: - return (column_ddl.strip(), "") - - col_def = column_ddl[: references_match.start()].strip() - fk_clause = references_match.group(1).strip() - col_name = col_def.split()[0] - fk_constraint = f"FOREIGN KEY ({col_name}) REFERENCES {fk_clause}" - - return (col_def, fk_constraint) - - async def _get_create_sessions_table_sql(self) -> str: - """Get MySQL CREATE TABLE SQL for sessions. - - Returns: - SQL statement to create adk_sessions table with indexes. - """ - owner_id_col = "" - fk_constraint = "" - - if self._owner_id_column_ddl: - col_def, fk_def = self._parse_owner_id_column_for_mysql(self._owner_id_column_ddl) - owner_id_col = f"{col_def}," - if fk_def: - fk_constraint = f",\n {fk_def}" - - return f""" - CREATE TABLE IF NOT EXISTS {self._session_table} ( - id VARCHAR(128) PRIMARY KEY, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - {owner_id_col} - state JSON NOT NULL, - create_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6), - INDEX idx_{self._session_table}_app_user (app_name, user_id), - INDEX idx_{self._session_table}_update_time (update_time DESC){fk_constraint} - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - """ - - async def _get_create_events_table_sql(self) -> str: - """Get MySQL CREATE TABLE SQL for events. - - Returns: - SQL statement to create adk_events table with indexes. - - Notes: - Post clean-break schema: 5 columns only. - - session_id, invocation_id, author: indexed scalars - - timestamp: microsecond-precision TIMESTAMP - - event_data: full Event as native JSON - """ - return f""" - CREATE TABLE IF NOT EXISTS {self._events_table} ( - session_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(128) NOT NULL, - timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - event_data JSON NOT NULL, - FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE, - INDEX idx_{self._events_table}_session (session_id, timestamp ASC) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - """ - - def _get_drop_tables_sql(self) -> "list[str]": - """Get MySQL DROP TABLE SQL statements. - - Returns: - List of SQL statements to drop tables and indexes. - """ - return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] - async def create_tables(self) -> None: """Create both sessions and events tables if they don't exist.""" async with self._config.provide_session() as driver: await driver.execute_script(await self._get_create_sessions_table_sql()) await driver.execute_script(await self._get_create_events_table_sql()) + await driver.execute_script(await self._get_create_app_states_table_sql()) + await driver.execute_script(await self._get_create_user_states_table_sql()) + await driver.execute_script(await self._get_create_metadata_table_sql()) + await driver.execute_script(await self._get_seed_metadata_sql()) async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None @@ -475,71 +413,235 @@ async def delete_idle_sessions(self, updated_before: "datetime") -> int: return 0 raise + async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Return app-scoped state for an application.""" + sql = f"SELECT state FROM {self._app_state_table} WHERE app_name = %s" -def _parse_owner_id_column_for_mysql(column_ddl: str) -> "tuple[str, str]": - """Parse owner ID column DDL for MySQL FOREIGN KEY syntax. + try: + async with self._config.provide_connection() as conn, conn.cursor() as cursor: + await cursor.execute(sql, (app_name,)) + row = await cursor.fetchone() + return from_json(row[0]) if row is not None and isinstance(row[0], str) else (row[0] if row else None) + except asyncmy.errors.ProgrammingError as e: # pyright: ignore[reportAttributeAccessIssue] + if "doesn't exist" in str(e) or e.args[0] == MYSQL_TABLE_NOT_FOUND_ERROR: + return None + raise - Args: - column_ddl: Column DDL like "tenant_id BIGINT NOT NULL REFERENCES tenants(id) ON DELETE CASCADE". + async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Return user-scoped state for an application user.""" + sql = f"SELECT state FROM {self._user_state_table} WHERE app_name = %s AND user_id = %s" - Returns: - Tuple of (column_definition, foreign_key_constraint). - """ - references_match = re.search(r"\s+REFERENCES\s+(.+)", column_ddl, re.IGNORECASE) - if not references_match: - return (column_ddl.strip(), "") + try: + async with self._config.provide_connection() as conn, conn.cursor() as cursor: + await cursor.execute(sql, (app_name, user_id)) + row = await cursor.fetchone() + return from_json(row[0]) if row is not None and isinstance(row[0], str) else (row[0] if row else None) + except asyncmy.errors.ProgrammingError as e: # pyright: ignore[reportAttributeAccessIssue] + if "doesn't exist" in str(e) or e.args[0] == MYSQL_TABLE_NOT_FOUND_ERROR: + return None + raise - col_def = column_ddl[: references_match.start()].strip() - fk_clause = references_match.group(1).strip() - col_name = col_def.split()[0] - fk_constraint = f"FOREIGN KEY ({col_name}) REFERENCES {fk_clause}" - return (col_def, fk_constraint) + async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Insert or replace app-scoped state for an application.""" + sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES (%s, %s, UTC_TIMESTAMP(6)) + ON DUPLICATE KEY UPDATE state = VALUES(state), update_time = UTC_TIMESTAMP(6) + """ + async with self._config.provide_connection() as conn, conn.cursor() as cursor: + await cursor.execute(sql, (app_name, to_json(state))) + await conn.commit() -class AsyncmyADKMemoryStore(BaseAsyncADKMemoryStore["AsyncmyConfig"]): - """MySQL/MariaDB ADK memory store using AsyncMy driver.""" + async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Insert or replace user-scoped state for an application user.""" + sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (%s, %s, %s, UTC_TIMESTAMP(6)) + ON DUPLICATE KEY UPDATE state = VALUES(state), update_time = UTC_TIMESTAMP(6) + """ - __slots__ = () + async with self._config.provide_connection() as conn, conn.cursor() as cursor: + await cursor.execute(sql, (app_name, user_id, to_json(state))) + await conn.commit() - def __init__(self, config: "AsyncmyConfig") -> None: - """Initialize AsyncMy memory store.""" - super().__init__(config) + async def get_metadata(self, key: str) -> "str | None": + """Return a value from the ADK internal metadata table.""" + sql = f"SELECT value FROM {self._metadata_table} WHERE `key` = %s" - async def _get_create_memory_table_sql(self) -> str: - """Get MySQL CREATE TABLE SQL for memory entries.""" - owner_id_line = "" + try: + async with self._config.provide_connection() as conn, conn.cursor() as cursor: + await cursor.execute(sql, (key,)) + row = await cursor.fetchone() + return row[0] if row is not None else None + except asyncmy.errors.ProgrammingError as e: # pyright: ignore[reportAttributeAccessIssue] + if "doesn't exist" in str(e) or e.args[0] == MYSQL_TABLE_NOT_FOUND_ERROR: + return None + raise + + async def set_metadata(self, key: str, value: str) -> None: + """Set a value in the ADK internal metadata table.""" + sql = f""" + INSERT INTO {self._metadata_table} (`key`, value) + VALUES (%s, %s) + ON DUPLICATE KEY UPDATE value = VALUES(value) + """ + + async with self._config.provide_connection() as conn, conn.cursor() as cursor: + await cursor.execute(sql, (key, value)) + await conn.commit() + + def _parse_owner_id_column_for_mysql(self, column_ddl: str) -> "tuple[str, str]": + """Parse owner ID column DDL for MySQL FOREIGN KEY syntax. + + MySQL ignores inline REFERENCES syntax in column definitions. + This method extracts the column definition and creates a separate + FOREIGN KEY constraint. + + Args: + column_ddl: Column DDL like "tenant_id BIGINT NOT NULL REFERENCES tenants(id) ON DELETE CASCADE" + + Returns: + Tuple of (column_definition, foreign_key_constraint) + """ + references_match = re.search(r"\s+REFERENCES\s+(.+)", column_ddl, re.IGNORECASE) + + if not references_match: + return (column_ddl.strip(), "") + + col_def = column_ddl[: references_match.start()].strip() + fk_clause = references_match.group(1).strip() + col_name = col_def.split()[0] + fk_constraint = f"FOREIGN KEY ({col_name}) REFERENCES {fk_clause}" + + return (col_def, fk_constraint) + + async def _get_create_sessions_table_sql(self) -> str: + """Get MySQL CREATE TABLE SQL for sessions. + + Returns: + SQL statement to create adk_session table with indexes. + """ + owner_id_col = "" fk_constraint = "" + if self._owner_id_column_ddl: - col_def, fk_def = _parse_owner_id_column_for_mysql(self._owner_id_column_ddl) - owner_id_line = f",\n {col_def}" + col_def, fk_def = self._parse_owner_id_column_for_mysql(self._owner_id_column_ddl) + owner_id_col = f"{col_def}," if fk_def: fk_constraint = f",\n {fk_def}" - fts_index = "" - if self._use_fts: - fts_index = f",\n FULLTEXT INDEX idx_{self._memory_table}_fts (content_text)" - return f""" - CREATE TABLE IF NOT EXISTS {self._memory_table} ( + CREATE TABLE IF NOT EXISTS {self._session_table} ( id VARCHAR(128) PRIMARY KEY, - session_id VARCHAR(128) NOT NULL, app_name VARCHAR(128) NOT NULL, user_id VARCHAR(128) NOT NULL, - event_id VARCHAR(128) NOT NULL UNIQUE, - author VARCHAR(256){owner_id_line}, + {owner_id_col} + state JSON NOT NULL, + create_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6), + INDEX idx_{self._session_table}_app_user (app_name, user_id), + INDEX idx_{self._session_table}_update_time (update_time DESC){fk_constraint} + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + async def _get_create_events_table_sql(self) -> str: + """Get MySQL CREATE TABLE SQL for events. + + Returns: + SQL statement to create adk_event table with indexes. + + Notes: + Post clean-break schema: 5 columns only. + - session_id, invocation_id, author: indexed scalars + - timestamp: microsecond-precision TIMESTAMP + - event_data: full Event as native JSON + """ + return f""" + CREATE TABLE IF NOT EXISTS {self._events_table} ( + session_id VARCHAR(128) NOT NULL, + invocation_id VARCHAR(256) NOT NULL, + author VARCHAR(128) NOT NULL, timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - content_json JSON NOT NULL, - content_text TEXT NOT NULL, - metadata_json JSON, - inserted_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - INDEX idx_{self._memory_table}_app_user_time (app_name, user_id, timestamp), - INDEX idx_{self._memory_table}_session (session_id){fts_index}{fk_constraint} + event_data JSON NOT NULL, + FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE, + INDEX idx_{self._events_table}_session (session_id, timestamp ASC) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci """ - def _get_drop_memory_table_sql(self) -> "list[str]": - """Get MySQL DROP TABLE SQL statements.""" - return [f"DROP TABLE IF EXISTS {self._memory_table}"] + async def _get_create_app_states_table_sql(self) -> str: + """Get MySQL CREATE TABLE SQL for app-scoped state.""" + return f""" + CREATE TABLE IF NOT EXISTS {self._app_state_table} ( + app_name VARCHAR(128) PRIMARY KEY, + state JSON NOT NULL, + update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + async def _get_create_user_states_table_sql(self) -> str: + """Get MySQL CREATE TABLE SQL for user-scoped state.""" + return f""" + CREATE TABLE IF NOT EXISTS {self._user_state_table} ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + state JSON NOT NULL, + update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6), + PRIMARY KEY (app_name, user_id) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + async def _get_create_metadata_table_sql(self) -> str: + """Get MySQL CREATE TABLE SQL for ADK internal metadata.""" + return f""" + CREATE TABLE IF NOT EXISTS {self._metadata_table} ( + `key` VARCHAR(128) PRIMARY KEY, + value VARCHAR(512) NOT NULL + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + async def _get_seed_metadata_sql(self) -> str: + """Get MySQL SQL that seeds the ADK schema version metadata row.""" + return f""" + INSERT IGNORE INTO {self._metadata_table} (`key`, value) + VALUES ('schema_version', '1') + """ + + def _get_drop_app_states_table_sql(self) -> str: + """Get MySQL DROP TABLE SQL for app-scoped state.""" + return f"DROP TABLE IF EXISTS {self._app_state_table}" + + def _get_drop_user_states_table_sql(self) -> str: + """Get MySQL DROP TABLE SQL for user-scoped state.""" + return f"DROP TABLE IF EXISTS {self._user_state_table}" + + def _get_drop_metadata_table_sql(self) -> str: + """Get MySQL DROP TABLE SQL for ADK internal metadata.""" + return f"DROP TABLE IF EXISTS {self._metadata_table}" + + def _get_drop_tables_sql(self) -> "list[str]": + """Get MySQL DROP TABLE SQL statements. + + Returns: + List of SQL statements to drop tables and indexes. + """ + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE IF EXISTS {self._events_table}", + f"DROP TABLE IF EXISTS {self._session_table}", + ] + + +class AsyncmyADKMemoryStore(BaseAsyncADKMemoryStore["AsyncmyConfig"]): + """MySQL/MariaDB ADK memory store using AsyncMy driver.""" + + __slots__ = () + + def __init__(self, config: "AsyncmyConfig") -> None: + """Initialize AsyncMy memory store.""" + super().__init__(config) async def create_tables(self) -> None: """Create the memory table and indexes if they don't exist.""" @@ -676,3 +778,39 @@ async def delete_entries_older_than(self, days: int) -> int: await cursor.execute(sql, (days,)) await conn.commit() return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + + async def _get_create_memory_table_sql(self) -> str: + """Get MySQL CREATE TABLE SQL for memory entries.""" + owner_id_line = "" + fk_constraint = "" + if self._owner_id_column_ddl: + col_def, fk_def = _parse_owner_id_column_for_mysql(self._owner_id_column_ddl) + owner_id_line = f",\n {col_def}" + if fk_def: + fk_constraint = f",\n {fk_def}" + + fts_index = "" + if self._use_fts: + fts_index = f",\n FULLTEXT INDEX idx_{self._memory_table}_fts (content_text)" + + return f""" + CREATE TABLE IF NOT EXISTS {self._memory_table} ( + id VARCHAR(128) PRIMARY KEY, + session_id VARCHAR(128) NOT NULL, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + event_id VARCHAR(128) NOT NULL UNIQUE, + author VARCHAR(256){owner_id_line}, + timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + content_json JSON NOT NULL, + content_text TEXT NOT NULL, + metadata_json JSON, + inserted_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + INDEX idx_{self._memory_table}_app_user_time (app_name, user_id, timestamp), + INDEX idx_{self._memory_table}_session (session_id){fts_index}{fk_constraint} + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + def _get_drop_memory_table_sql(self) -> "list[str]": + """Get MySQL DROP TABLE SQL statements.""" + return [f"DROP TABLE IF EXISTS {self._memory_table}"] diff --git a/sqlspec/adapters/asyncpg/adk/store.py b/sqlspec/adapters/asyncpg/adk/store.py index 06775350f..d26e174ee 100644 --- a/sqlspec/adapters/asyncpg/adk/store.py +++ b/sqlspec/adapters/asyncpg/adk/store.py @@ -46,54 +46,14 @@ class AsyncpgADKStore(BaseAsyncADKStore[AsyncConfigT]): def __init__(self, config: AsyncConfigT) -> None: super().__init__(config) - async def _get_create_sessions_table_sql(self) -> str: - owner_id_line = "" - if self._owner_id_column_ddl: - owner_id_line = f",\n {self._owner_id_column_ddl}" - - return f""" - CREATE TABLE IF NOT EXISTS {self._session_table} ( - id VARCHAR(128) PRIMARY KEY, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL{owner_id_line}, - state JSONB NOT NULL DEFAULT '{{}}'::jsonb, - create_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP - ) WITH (fillfactor = 80); - - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_app_user - ON {self._session_table}(app_name, user_id); - - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time - ON {self._session_table}(update_time DESC); - - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_state - ON {self._session_table} USING GIN (state) - WHERE state != '{{}}'::jsonb; - """ - - async def _get_create_events_table_sql(self) -> str: - return f""" - CREATE TABLE IF NOT EXISTS {self._events_table} ( - session_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(256) NOT NULL, - timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - event_data JSONB NOT NULL, - FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE - ) WITH (fillfactor = 80); - - CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session - ON {self._events_table}(session_id, timestamp ASC); - """ - - def _get_drop_tables_sql(self) -> "list[str]": - return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] - async def create_tables(self) -> None: async with self._config.provide_session() as driver: await driver.execute_script(await self._get_create_sessions_table_sql()) await driver.execute_script(await self._get_create_events_table_sql()) + await driver.execute_script(await self._get_create_app_states_table_sql()) + await driver.execute_script(await self._get_create_user_states_table_sql()) + await driver.execute_script(await self._get_create_metadata_table_sql()) + await driver.execute_script(await self._get_seed_metadata_sql()) async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None @@ -317,6 +277,164 @@ async def delete_idle_sessions(self, updated_before: "datetime") -> int: except asyncpg.exceptions.UndefinedTableError: return 0 + async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + sql = f"SELECT state FROM {self._app_state_table} WHERE app_name = $1" + + try: + async with self._config.provide_connection() as conn: + row = await conn.fetchrow(sql, app_name) + return row["state"] if row is not None else None + except asyncpg.exceptions.UndefinedTableError: + return None + + async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + sql = f"SELECT state FROM {self._user_state_table} WHERE app_name = $1 AND user_id = $2" + + try: + async with self._config.provide_connection() as conn: + row = await conn.fetchrow(sql, app_name, user_id) + return row["state"] if row is not None else None + except asyncpg.exceptions.UndefinedTableError: + return None + + async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES ($1, $2, CURRENT_TIMESTAMP) + ON CONFLICT (app_name) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """ + + async with self._config.provide_connection() as conn: + await conn.execute(sql, app_name, state) + + async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES ($1, $2, $3, CURRENT_TIMESTAMP) + ON CONFLICT (app_name, user_id) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """ + + async with self._config.provide_connection() as conn: + await conn.execute(sql, app_name, user_id, state) + + async def get_metadata(self, key: str) -> "str | None": + sql = f"SELECT value FROM {self._metadata_table} WHERE key = $1" + + try: + async with self._config.provide_connection() as conn: + row = await conn.fetchrow(sql, key) + return row["value"] if row is not None else None + except asyncpg.exceptions.UndefinedTableError: + return None + + async def set_metadata(self, key: str, value: str) -> None: + sql = f""" + INSERT INTO {self._metadata_table} (key, value) + VALUES ($1, $2) + ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value + """ + + async with self._config.provide_connection() as conn: + await conn.execute(sql, key, value) + + async def _get_create_sessions_table_sql(self) -> str: + owner_id_line = "" + if self._owner_id_column_ddl: + owner_id_line = f",\n {self._owner_id_column_ddl}" + + return f""" + CREATE TABLE IF NOT EXISTS {self._session_table} ( + id VARCHAR(128) PRIMARY KEY, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL{owner_id_line}, + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + create_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ) WITH (fillfactor = 80); + + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_app_user + ON {self._session_table}(app_name, user_id); + + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time + ON {self._session_table}(update_time DESC); + + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_state + ON {self._session_table} USING GIN (state) + WHERE state != '{{}}'::jsonb; + """ + + async def _get_create_events_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._events_table} ( + session_id VARCHAR(128) NOT NULL, + invocation_id VARCHAR(256) NOT NULL, + author VARCHAR(256) NOT NULL, + timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + event_data JSONB NOT NULL, + FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE + ) WITH (fillfactor = 80); + + CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session + ON {self._events_table}(session_id, timestamp ASC); + """ + + async def _get_create_app_states_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._app_state_table} ( + app_name VARCHAR(128) PRIMARY KEY, + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ) WITH (fillfactor = 80); + """ + + async def _get_create_user_states_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._user_state_table} ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (app_name, user_id) + ) WITH (fillfactor = 80); + """ + + async def _get_create_metadata_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._metadata_table} ( + key VARCHAR(128) PRIMARY KEY, + value VARCHAR(512) NOT NULL + ); + """ + + async def _get_seed_metadata_sql(self) -> str: + return f""" + INSERT INTO {self._metadata_table} (key, value) + VALUES ('schema_version', '1') + ON CONFLICT (key) DO NOTHING + """ + + def _get_drop_app_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._app_state_table}" + + def _get_drop_user_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._user_state_table}" + + def _get_drop_metadata_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._metadata_table}" + + def _get_drop_tables_sql(self) -> "list[str]": + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE IF EXISTS {self._events_table}", + f"DROP TABLE IF EXISTS {self._session_table}", + ] + class AsyncpgADKMemoryStore(BaseAsyncADKMemoryStore["AsyncpgConfig"]): """PostgreSQL ADK memory store using asyncpg driver. @@ -339,44 +457,6 @@ class AsyncpgADKMemoryStore(BaseAsyncADKMemoryStore["AsyncpgConfig"]): def __init__(self, config: "AsyncpgConfig") -> None: super().__init__(config) - async def _get_create_memory_table_sql(self) -> str: - owner_id_line = "" - if self._owner_id_column_ddl: - owner_id_line = f",\n {self._owner_id_column_ddl}" - - fts_index = "" - if self._use_fts: - fts_index = f""" - CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_fts - ON {self._memory_table} USING GIN (to_tsvector('english', content_text)); - """ - - return f""" - CREATE TABLE IF NOT EXISTS {self._memory_table} ( - id VARCHAR(128) PRIMARY KEY, - session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - event_id VARCHAR(128) NOT NULL UNIQUE, - author VARCHAR(256){owner_id_line}, - timestamp TIMESTAMPTZ NOT NULL, - content_json JSONB NOT NULL, - content_text TEXT NOT NULL, - metadata_json JSONB, - inserted_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP - ); - - CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_app_user_time - ON {self._memory_table}(app_name, user_id, timestamp DESC); - - CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_session - ON {self._memory_table}(session_id); - {fts_index} - """ - - def _get_drop_memory_table_sql(self) -> "list[str]": - return [f"DROP TABLE IF EXISTS {self._memory_table}"] - async def create_tables(self) -> None: if not self._enabled: return @@ -511,3 +591,41 @@ async def delete_entries_older_than(self, days: int) -> int: return int(result.split(" ")[1]) except (IndexError, ValueError): return 0 + + async def _get_create_memory_table_sql(self) -> str: + owner_id_line = "" + if self._owner_id_column_ddl: + owner_id_line = f",\n {self._owner_id_column_ddl}" + + fts_index = "" + if self._use_fts: + fts_index = f""" + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_fts + ON {self._memory_table} USING GIN (to_tsvector('english', content_text)); + """ + + return f""" + CREATE TABLE IF NOT EXISTS {self._memory_table} ( + id VARCHAR(128) PRIMARY KEY, + session_id VARCHAR(128) NOT NULL, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + event_id VARCHAR(128) NOT NULL UNIQUE, + author VARCHAR(256){owner_id_line}, + timestamp TIMESTAMPTZ NOT NULL, + content_json JSONB NOT NULL, + content_text TEXT NOT NULL, + metadata_json JSONB, + inserted_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_app_user_time + ON {self._memory_table}(app_name, user_id, timestamp DESC); + + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_session + ON {self._memory_table}(session_id); + {fts_index} + """ + + def _get_drop_memory_table_sql(self) -> "list[str]": + return [f"DROP TABLE IF EXISTS {self._memory_table}"] diff --git a/sqlspec/adapters/cockroach_asyncpg/adk/store.py b/sqlspec/adapters/cockroach_asyncpg/adk/store.py index 942c539d6..073497182 100644 --- a/sqlspec/adapters/cockroach_asyncpg/adk/store.py +++ b/sqlspec/adapters/cockroach_asyncpg/adk/store.py @@ -40,57 +40,14 @@ class CockroachAsyncpgADKStore(BaseAsyncADKStore["CockroachAsyncpgConfig"]): def __init__(self, config: "CockroachAsyncpgConfig") -> None: super().__init__(config) - async def _get_create_sessions_table_sql(self) -> str: - owner_id_line = "" - if self._owner_id_column_ddl: - owner_id_line = f",\n {self._owner_id_column_ddl}" - - return f""" - CREATE TABLE IF NOT EXISTS {self._session_table} ( - id VARCHAR(128) PRIMARY KEY, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL{owner_id_line}, - state JSONB NOT NULL DEFAULT '{{}}'::jsonb, - create_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP - ); - - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_app_user - ON {self._session_table}(app_name, user_id); - - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time - ON {self._session_table}(update_time DESC); - - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_state - ON {self._session_table} USING GIN (state) - WHERE state != '{{}}'::jsonb; - """ - - async def _get_create_events_table_sql(self) -> str: - return f""" - CREATE TABLE IF NOT EXISTS {self._events_table} ( - session_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(256) NOT NULL, - timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - event_data JSONB NOT NULL, - FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE - ); - - CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session - ON {self._events_table}(session_id, timestamp ASC); - - CREATE INDEX IF NOT EXISTS idx_{self._events_table}_event_data - ON {self._events_table} USING GIN (event_data); - """ - - def _get_drop_tables_sql(self) -> "list[str]": - return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] - async def create_tables(self) -> None: async with self._config.provide_session() as driver: await driver.execute_script(await self._get_create_sessions_table_sql()) await driver.execute_script(await self._get_create_events_table_sql()) + await driver.execute_script(await self._get_create_app_states_table_sql()) + await driver.execute_script(await self._get_create_user_states_table_sql()) + await driver.execute_script(await self._get_create_metadata_table_sql()) + await driver.execute_script(await self._get_seed_metadata_sql()) async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None @@ -319,52 +276,168 @@ async def delete_idle_sessions(self, updated_before: "datetime") -> int: except asyncpg.exceptions.UndefinedTableError: return 0 + async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + sql = f"SELECT state FROM {self._app_state_table} WHERE app_name = $1" -class CockroachAsyncpgADKMemoryStore(BaseAsyncADKMemoryStore["CockroachAsyncpgConfig"]): - """CockroachDB ADK memory store using asyncpg driver.""" + try: + async with self._config.provide_connection() as conn: + row = await conn.fetchrow(sql, app_name) + return row["state"] if row is not None else None + except asyncpg.exceptions.UndefinedTableError: + return None - __slots__ = () + async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + sql = f"SELECT state FROM {self._user_state_table} WHERE app_name = $1 AND user_id = $2" - def __init__(self, config: "CockroachAsyncpgConfig") -> None: - super().__init__(config) + try: + async with self._config.provide_connection() as conn: + row = await conn.fetchrow(sql, app_name, user_id) + return row["state"] if row is not None else None + except asyncpg.exceptions.UndefinedTableError: + return None - async def _get_create_memory_table_sql(self) -> str: + async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + sql = f""" + UPSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES ($1, $2, CURRENT_TIMESTAMP) + """ + + async with self._config.provide_connection() as conn: + await conn.execute(sql, app_name, state) + + async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + sql = f""" + UPSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES ($1, $2, $3, CURRENT_TIMESTAMP) + """ + + async with self._config.provide_connection() as conn: + await conn.execute(sql, app_name, user_id, state) + + async def get_metadata(self, key: str) -> "str | None": + sql = f"SELECT value FROM {self._metadata_table} WHERE key = $1" + + try: + async with self._config.provide_connection() as conn: + row = await conn.fetchrow(sql, key) + return row["value"] if row is not None else None + except asyncpg.exceptions.UndefinedTableError: + return None + + async def set_metadata(self, key: str, value: str) -> None: + sql = f""" + UPSERT INTO {self._metadata_table} (key, value) + VALUES ($1, $2) + """ + + async with self._config.provide_connection() as conn: + await conn.execute(sql, key, value) + + async def _get_create_sessions_table_sql(self) -> str: owner_id_line = "" if self._owner_id_column_ddl: owner_id_line = f",\n {self._owner_id_column_ddl}" - fts_index = "" - if self._use_fts: - fts_index = f""" - CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_fts - ON {self._memory_table} USING GIN (to_tsvector('english', content_text)); - """ - return f""" - CREATE TABLE IF NOT EXISTS {self._memory_table} ( + CREATE TABLE IF NOT EXISTS {self._session_table} ( id VARCHAR(128) PRIMARY KEY, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL{owner_id_line}, + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + create_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_app_user + ON {self._session_table}(app_name, user_id); + + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time + ON {self._session_table}(update_time DESC); + + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_state + ON {self._session_table} USING GIN (state) + WHERE state != '{{}}'::jsonb; + """ + + async def _get_create_events_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._events_table} ( session_id VARCHAR(128) NOT NULL, + invocation_id VARCHAR(256) NOT NULL, + author VARCHAR(256) NOT NULL, + timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + event_data JSONB NOT NULL, + FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE + ); + + CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session + ON {self._events_table}(session_id, timestamp ASC); + + CREATE INDEX IF NOT EXISTS idx_{self._events_table}_event_data + ON {self._events_table} USING GIN (event_data); + """ + + async def _get_create_app_states_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._app_state_table} ( + app_name VARCHAR(128) PRIMARY KEY, + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + """ + + async def _get_create_user_states_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._user_state_table} ( app_name VARCHAR(128) NOT NULL, user_id VARCHAR(128) NOT NULL, - event_id VARCHAR(128) NOT NULL UNIQUE, - author VARCHAR(256){owner_id_line}, - timestamp TIMESTAMPTZ NOT NULL, - content_json JSONB NOT NULL, - content_text TEXT NOT NULL, - metadata_json JSONB, - inserted_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (app_name, user_id) ); + """ - CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_app_user_time - ON {self._memory_table}(app_name, user_id, timestamp DESC); + async def _get_create_metadata_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._metadata_table} ( + key VARCHAR(128) PRIMARY KEY, + value VARCHAR(512) NOT NULL + ); + """ - CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_session - ON {self._memory_table}(session_id); - {fts_index} + async def _get_seed_metadata_sql(self) -> str: + return f""" + INSERT INTO {self._metadata_table} (key, value) + VALUES ('schema_version', '1') + ON CONFLICT (key) DO NOTHING """ - def _get_drop_memory_table_sql(self) -> "list[str]": - return [f"DROP TABLE IF EXISTS {self._memory_table}"] + def _get_drop_app_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._app_state_table}" + + def _get_drop_user_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._user_state_table}" + + def _get_drop_metadata_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._metadata_table}" + + def _get_drop_tables_sql(self) -> "list[str]": + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE IF EXISTS {self._events_table}", + f"DROP TABLE IF EXISTS {self._session_table}", + ] + + +class CockroachAsyncpgADKMemoryStore(BaseAsyncADKMemoryStore["CockroachAsyncpgConfig"]): + """CockroachDB ADK memory store using asyncpg driver.""" + + __slots__ = () + + def __init__(self, config: "CockroachAsyncpgConfig") -> None: + super().__init__(config) async def create_tables(self) -> None: if not self._enabled: @@ -493,3 +566,41 @@ async def delete_entries_older_than(self, days: int) -> int: async with self._config.provide_connection() as conn: result = await conn.execute(sql) return int(result.split()[-1]) if result else 0 + + async def _get_create_memory_table_sql(self) -> str: + owner_id_line = "" + if self._owner_id_column_ddl: + owner_id_line = f",\n {self._owner_id_column_ddl}" + + fts_index = "" + if self._use_fts: + fts_index = f""" + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_fts + ON {self._memory_table} USING GIN (to_tsvector('english', content_text)); + """ + + return f""" + CREATE TABLE IF NOT EXISTS {self._memory_table} ( + id VARCHAR(128) PRIMARY KEY, + session_id VARCHAR(128) NOT NULL, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + event_id VARCHAR(128) NOT NULL UNIQUE, + author VARCHAR(256){owner_id_line}, + timestamp TIMESTAMPTZ NOT NULL, + content_json JSONB NOT NULL, + content_text TEXT NOT NULL, + metadata_json JSONB, + inserted_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_app_user_time + ON {self._memory_table}(app_name, user_id, timestamp DESC); + + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_session + ON {self._memory_table}(session_id); + {fts_index} + """ + + def _get_drop_memory_table_sql(self) -> "list[str]": + return [f"DROP TABLE IF EXISTS {self._memory_table}"] diff --git a/sqlspec/adapters/cockroach_psycopg/adk/store.py b/sqlspec/adapters/cockroach_psycopg/adk/store.py index bcab87774..97436c80d 100644 --- a/sqlspec/adapters/cockroach_psycopg/adk/store.py +++ b/sqlspec/adapters/cockroach_psycopg/adk/store.py @@ -81,57 +81,14 @@ class CockroachPsycopgAsyncADKStore(BaseAsyncADKStore["CockroachPsycopgAsyncConf def __init__(self, config: "CockroachPsycopgAsyncConfig") -> None: super().__init__(config) - async def _get_create_sessions_table_sql(self) -> str: - owner_id_line = "" - if self._owner_id_column_ddl: - owner_id_line = f",\n {self._owner_id_column_ddl}" - - return f""" - CREATE TABLE IF NOT EXISTS {self._session_table} ( - id VARCHAR(128) PRIMARY KEY, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL{owner_id_line}, - state JSONB NOT NULL DEFAULT '{{}}'::jsonb, - create_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP - ); - - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_app_user - ON {self._session_table}(app_name, user_id); - - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time - ON {self._session_table}(update_time DESC); - - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_state - ON {self._session_table} USING GIN (state) - WHERE state != '{{}}'::jsonb; - """ - - async def _get_create_events_table_sql(self) -> str: - return f""" - CREATE TABLE IF NOT EXISTS {self._events_table} ( - session_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(256) NOT NULL, - timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - event_data JSONB NOT NULL, - FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE - ); - - CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session - ON {self._events_table}(session_id, timestamp ASC); - - CREATE INDEX IF NOT EXISTS idx_{self._events_table}_event_data - ON {self._events_table} USING GIN (event_data); - """ - - def _get_drop_tables_sql(self) -> "list[str]": - return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] - async def create_tables(self) -> None: async with self._config.provide_session() as driver: await driver.execute_script(await self._get_create_sessions_table_sql()) await driver.execute_script(await self._get_create_events_table_sql()) + await driver.execute_script(await self._get_create_app_states_table_sql()) + await driver.execute_script(await self._get_create_user_states_table_sql()) + await driver.execute_script(await self._get_create_metadata_table_sql()) + await driver.execute_script(await self._get_seed_metadata_sql()) async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None @@ -384,6 +341,166 @@ async def delete_idle_sessions(self, updated_before: "datetime") -> int: except errors.UndefinedTable: return 0 + async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + sql = f"SELECT state FROM {self._app_state_table} WHERE app_name = %s" + + try: + async with self._config.provide_connection() as conn, conn.cursor() as cur: + await cur.execute(sql.encode(), (app_name,)) + row = await cur.fetchone() + return row["state"] if row is not None else None + except errors.UndefinedTable: + return None + + async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + sql = f"SELECT state FROM {self._user_state_table} WHERE app_name = %s AND user_id = %s" + + try: + async with self._config.provide_connection() as conn, conn.cursor() as cur: + await cur.execute(sql.encode(), (app_name, user_id)) + row = await cur.fetchone() + return row["state"] if row is not None else None + except errors.UndefinedTable: + return None + + async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + sql = f""" + UPSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES (%s, %s, CURRENT_TIMESTAMP) + """ + + async with self._config.provide_connection() as conn, conn.cursor() as cur: + await cur.execute(sql.encode(), (app_name, Jsonb(state))) + await conn.commit() + + async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + sql = f""" + UPSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (%s, %s, %s, CURRENT_TIMESTAMP) + """ + + async with self._config.provide_connection() as conn, conn.cursor() as cur: + await cur.execute(sql.encode(), (app_name, user_id, Jsonb(state))) + await conn.commit() + + async def get_metadata(self, key: str) -> "str | None": + sql = f"SELECT value FROM {self._metadata_table} WHERE key = %s" + + try: + async with self._config.provide_connection() as conn, conn.cursor() as cur: + await cur.execute(sql.encode(), (key,)) + row = await cur.fetchone() + return row["value"] if row is not None else None + except errors.UndefinedTable: + return None + + async def set_metadata(self, key: str, value: str) -> None: + sql = f""" + UPSERT INTO {self._metadata_table} (key, value) + VALUES (%s, %s) + """ + + async with self._config.provide_connection() as conn, conn.cursor() as cur: + await cur.execute(sql.encode(), (key, value)) + await conn.commit() + + async def _get_create_sessions_table_sql(self) -> str: + owner_id_line = "" + if self._owner_id_column_ddl: + owner_id_line = f",\n {self._owner_id_column_ddl}" + + return f""" + CREATE TABLE IF NOT EXISTS {self._session_table} ( + id VARCHAR(128) PRIMARY KEY, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL{owner_id_line}, + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + create_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_app_user + ON {self._session_table}(app_name, user_id); + + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time + ON {self._session_table}(update_time DESC); + + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_state + ON {self._session_table} USING GIN (state) + WHERE state != '{{}}'::jsonb; + """ + + async def _get_create_events_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._events_table} ( + session_id VARCHAR(128) NOT NULL, + invocation_id VARCHAR(256) NOT NULL, + author VARCHAR(256) NOT NULL, + timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + event_data JSONB NOT NULL, + FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE + ); + + CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session + ON {self._events_table}(session_id, timestamp ASC); + + CREATE INDEX IF NOT EXISTS idx_{self._events_table}_event_data + ON {self._events_table} USING GIN (event_data); + """ + + async def _get_create_app_states_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._app_state_table} ( + app_name VARCHAR(128) PRIMARY KEY, + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + """ + + async def _get_create_user_states_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._user_state_table} ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (app_name, user_id) + ); + """ + + async def _get_create_metadata_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._metadata_table} ( + key VARCHAR(128) PRIMARY KEY, + value VARCHAR(512) NOT NULL + ); + """ + + async def _get_seed_metadata_sql(self) -> str: + return f""" + INSERT INTO {self._metadata_table} (key, value) + VALUES ('schema_version', '1') + ON CONFLICT (key) DO NOTHING + """ + + def _get_drop_app_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._app_state_table}" + + def _get_drop_user_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._user_state_table}" + + def _get_drop_metadata_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._metadata_table}" + + def _get_drop_tables_sql(self) -> "list[str]": + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE IF EXISTS {self._events_table}", + f"DROP TABLE IF EXISTS {self._session_table}", + ] + class CockroachPsycopgSyncADKStore(BaseAsyncADKStore["CockroachPsycopgSyncConfig"]): """CockroachDB ADK store using psycopg sync driver. @@ -404,6 +521,82 @@ class CockroachPsycopgSyncADKStore(BaseAsyncADKStore["CockroachPsycopgSyncConfig def __init__(self, config: "CockroachPsycopgSyncConfig") -> None: super().__init__(config) + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + + async def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new session.""" + return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) + + async def get_session( + self, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + """Get session by ID.""" + return await async_(self._get_session)(session_id, renew_for) + + async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + """Update session state.""" + await async_(self._update_session_state)(session_id, state) + + async def delete_session(self, session_id: str) -> None: + """Delete session and associated events.""" + await async_(self._delete_session)(session_id) + + async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app.""" + return await async_(self._list_sessions)(app_name, user_id) + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> SessionRecord: + """Atomically append an event and update the session's durable state.""" + return await async_(self._append_event_and_update_state)(event_record, session_id, state) + + async def get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + """Get events for a session.""" + return await async_(self._get_events)(session_id, after_timestamp, limit) + + async def delete_expired_events(self, before: "datetime") -> int: + """Delete events older than the given timestamp.""" + return await async_(self._delete_expired_events)(before) + + async def delete_idle_sessions(self, updated_before: "datetime") -> int: + """Delete sessions whose update_time predates the given threshold.""" + return await async_(self._delete_idle_sessions)(updated_before) + + async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Return app-scoped state for an application.""" + return await async_(self._get_app_state)(app_name) + + async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Return user-scoped state for an application user.""" + return await async_(self._get_user_state)(app_name, user_id) + + async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Insert or replace app-scoped state for an application.""" + await async_(self._upsert_app_state)(app_name, state) + + async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Insert or replace user-scoped state for an application user.""" + await async_(self._upsert_user_state)(app_name, user_id, state) + + async def get_metadata(self, key: str) -> "str | None": + """Return a value from the ADK internal metadata table.""" + return await async_(self._get_metadata)(key) + + async def set_metadata(self, key: str, value: str) -> None: + """Set a value in the ADK internal metadata table.""" + await async_(self._set_metadata)(key, value) + + async def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session.""" + await async_(self._append_event)(event_record) + async def _get_create_sessions_table_sql(self) -> str: owner_id_line = "" if self._owner_id_column_ddl: @@ -448,17 +641,67 @@ async def _get_create_events_table_sql(self) -> str: ON {self._events_table} USING GIN (event_data); """ + async def _get_create_app_states_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._app_state_table} ( + app_name VARCHAR(128) PRIMARY KEY, + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + """ + + async def _get_create_user_states_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._user_state_table} ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (app_name, user_id) + ); + """ + + async def _get_create_metadata_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._metadata_table} ( + key VARCHAR(128) PRIMARY KEY, + value VARCHAR(512) NOT NULL + ); + """ + + async def _get_seed_metadata_sql(self) -> str: + return f""" + INSERT INTO {self._metadata_table} (key, value) + VALUES ('schema_version', '1') + ON CONFLICT (key) DO NOTHING + """ + + def _get_drop_app_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._app_state_table}" + + def _get_drop_user_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._user_state_table}" + + def _get_drop_metadata_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._metadata_table}" + def _get_drop_tables_sql(self) -> "list[str]": - return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE IF EXISTS {self._events_table}", + f"DROP TABLE IF EXISTS {self._session_table}", + ] def _create_tables(self) -> None: with self._config.provide_session() as driver: driver.execute_script(run_(self._get_create_sessions_table_sql)()) driver.execute_script(run_(self._get_create_events_table_sql)()) - - async def create_tables(self) -> None: - """Create tables if they don't exist.""" - await async_(self._create_tables)() + driver.execute_script(run_(self._get_create_app_states_table_sql)()) + driver.execute_script(run_(self._get_create_user_states_table_sql)()) + driver.execute_script(run_(self._get_create_metadata_table_sql)()) + driver.execute_script(run_(self._get_seed_metadata_sql)()) def _create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None @@ -489,12 +732,6 @@ def _create_session( raise RuntimeError(msg) return result - async def create_session( - self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None - ) -> SessionRecord: - """Create a new session.""" - return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) - def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = None) -> "SessionRecord | None": if renew_for is not None and self._calculate_expires_at(renew_for) is not None: sql = f""" @@ -529,12 +766,6 @@ def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = No except errors.UndefinedTable: return None - async def get_session( - self, session_id: str, *, renew_for: "int | timedelta | None" = None - ) -> "SessionRecord | None": - """Get session by ID.""" - return await async_(self._get_session)(session_id, renew_for) - def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: sql = f""" UPDATE {self._session_table} @@ -546,10 +777,6 @@ def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> Non cur.execute(sql.encode(), (Jsonb(state), session_id)) conn.commit() - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state.""" - await async_(self._update_session_state)(session_id, state) - def _delete_session(self, session_id: str) -> None: sql = f"DELETE FROM {self._session_table} WHERE id = %s" @@ -557,10 +784,6 @@ def _delete_session(self, session_id: str) -> None: cur.execute(sql.encode(), (session_id,)) conn.commit() - async def delete_session(self, session_id: str) -> None: - """Delete session and associated events.""" - await async_(self._delete_session)(session_id) - def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": if user_id is None: sql = f""" @@ -598,10 +821,6 @@ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Ses except errors.UndefinedTable: return [] - async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app.""" - return await async_(self._list_sessions)(app_name, user_id) - def _append_event_and_update_state( self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" ) -> SessionRecord: @@ -648,12 +867,6 @@ def _append_event_and_update_state( update_time=row["update_time"], ) - async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" - ) -> SessionRecord: - """Atomically append an event and update the session's durable state.""" - return await async_(self._append_event_and_update_state)(event_record, session_id, state) - def _insert_event(self, event_record: EventRecord) -> None: sql = f""" INSERT INTO {self._events_table} ( @@ -716,12 +929,6 @@ def _get_events( except errors.UndefinedTable: return [] - async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None - ) -> "list[EventRecord]": - """Get events for a session.""" - return await async_(self._get_events)(session_id, after_timestamp, limit) - def _delete_expired_events(self, before: "datetime") -> int: sql = f"DELETE FROM {self._events_table} WHERE timestamp < %s" @@ -733,10 +940,6 @@ def _delete_expired_events(self, before: "datetime") -> int: except errors.UndefinedTable: return 0 - async def delete_expired_events(self, before: "datetime") -> int: - """Delete events older than the given timestamp.""" - return await async_(self._delete_expired_events)(before) - def _delete_idle_sessions(self, updated_before: "datetime") -> int: sql = f"DELETE FROM {self._session_table} WHERE update_time < %s" @@ -748,18 +951,73 @@ def _delete_idle_sessions(self, updated_before: "datetime") -> int: except errors.UndefinedTable: return 0 - async def delete_idle_sessions(self, updated_before: "datetime") -> int: - """Delete sessions whose update_time predates the given threshold.""" - return await async_(self._delete_idle_sessions)(updated_before) + def _get_app_state(self, app_name: str) -> "dict[str, Any] | None": + sql = f"SELECT state FROM {self._app_state_table} WHERE app_name = %s" + + try: + with self._config.provide_connection() as conn, conn.cursor() as cur: + cur.execute(sql.encode(), (app_name,)) + row = cur.fetchone() + return row["state"] if row is not None else None + except errors.UndefinedTable: + return None + + def _get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + sql = f"SELECT state FROM {self._user_state_table} WHERE app_name = %s AND user_id = %s" + + try: + with self._config.provide_connection() as conn, conn.cursor() as cur: + cur.execute(sql.encode(), (app_name, user_id)) + row = cur.fetchone() + return row["state"] if row is not None else None + except errors.UndefinedTable: + return None + + def _upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + sql = f""" + UPSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES (%s, %s, CURRENT_TIMESTAMP) + """ + + with self._config.provide_connection() as conn, conn.cursor() as cur: + cur.execute(sql.encode(), (app_name, Jsonb(state))) + conn.commit() + + def _upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + sql = f""" + UPSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (%s, %s, %s, CURRENT_TIMESTAMP) + """ + + with self._config.provide_connection() as conn, conn.cursor() as cur: + cur.execute(sql.encode(), (app_name, user_id, Jsonb(state))) + conn.commit() + + def _get_metadata(self, key: str) -> "str | None": + sql = f"SELECT value FROM {self._metadata_table} WHERE key = %s" + + try: + with self._config.provide_connection() as conn, conn.cursor() as cur: + cur.execute(sql.encode(), (key,)) + row = cur.fetchone() + return row["value"] if row is not None else None + except errors.UndefinedTable: + return None + + def _set_metadata(self, key: str, value: str) -> None: + sql = f""" + UPSERT INTO {self._metadata_table} (key, value) + VALUES (%s, %s) + """ + + with self._config.provide_connection() as conn, conn.cursor() as cur: + cur.execute(sql.encode(), (key, value)) + conn.commit() def _append_event(self, event_record: EventRecord) -> None: """Synchronous implementation of append_event.""" self._insert_event(event_record) - async def append_event(self, event_record: EventRecord) -> None: - """Append an event to a session.""" - await async_(self._append_event)(event_record) - class CockroachPsycopgAsyncADKMemoryStore(BaseAsyncADKMemoryStore["CockroachPsycopgAsyncConfig"]): """CockroachDB ADK memory store using psycopg async driver.""" @@ -769,44 +1027,6 @@ class CockroachPsycopgAsyncADKMemoryStore(BaseAsyncADKMemoryStore["CockroachPsyc def __init__(self, config: "CockroachPsycopgAsyncConfig") -> None: super().__init__(config) - async def _get_create_memory_table_sql(self) -> str: - owner_id_line = "" - if self._owner_id_column_ddl: - owner_id_line = f",\n {self._owner_id_column_ddl}" - - fts_index = "" - if self._use_fts: - fts_index = f""" - CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_fts - ON {self._memory_table} USING GIN (to_tsvector('english', content_text)); - """ - - return f""" - CREATE TABLE IF NOT EXISTS {self._memory_table} ( - id VARCHAR(128) PRIMARY KEY, - session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - event_id VARCHAR(128) NOT NULL UNIQUE, - author VARCHAR(256){owner_id_line}, - timestamp TIMESTAMPTZ NOT NULL, - content_json JSONB NOT NULL, - content_text TEXT NOT NULL, - metadata_json JSONB, - inserted_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP - ); - - CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_app_user_time - ON {self._memory_table}(app_name, user_id, timestamp DESC); - - CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_session - ON {self._memory_table}(session_id); - {fts_index} - """ - - def _get_drop_memory_table_sql(self) -> "list[str]": - return [f"DROP TABLE IF EXISTS {self._memory_table}"] - async def create_tables(self) -> None: if not self._enabled: return @@ -925,6 +1145,44 @@ async def delete_entries_older_than(self, days: int) -> int: await conn.commit() return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 + async def _get_create_memory_table_sql(self) -> str: + owner_id_line = "" + if self._owner_id_column_ddl: + owner_id_line = f",\n {self._owner_id_column_ddl}" + + fts_index = "" + if self._use_fts: + fts_index = f""" + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_fts + ON {self._memory_table} USING GIN (to_tsvector('english', content_text)); + """ + + return f""" + CREATE TABLE IF NOT EXISTS {self._memory_table} ( + id VARCHAR(128) PRIMARY KEY, + session_id VARCHAR(128) NOT NULL, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + event_id VARCHAR(128) NOT NULL UNIQUE, + author VARCHAR(256){owner_id_line}, + timestamp TIMESTAMPTZ NOT NULL, + content_json JSONB NOT NULL, + content_text TEXT NOT NULL, + metadata_json JSONB, + inserted_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_app_user_time + ON {self._memory_table}(app_name, user_id, timestamp DESC); + + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_session + ON {self._memory_table}(session_id); + {fts_index} + """ + + def _get_drop_memory_table_sql(self) -> "list[str]": + return [f"DROP TABLE IF EXISTS {self._memory_table}"] + class CockroachPsycopgSyncADKMemoryStore(BaseAsyncADKMemoryStore["CockroachPsycopgSyncConfig"]): """CockroachDB ADK memory store using psycopg sync driver.""" @@ -934,6 +1192,28 @@ class CockroachPsycopgSyncADKMemoryStore(BaseAsyncADKMemoryStore["CockroachPsyco def __init__(self, config: "CockroachPsycopgSyncConfig") -> None: super().__init__(config) + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + + async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication.""" + return await async_(self._insert_memory_entries)(entries, owner_id) + + async def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query.""" + return await async_(self._search_entries)(query, app_name, user_id, limit) + + async def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return await async_(self._delete_entries_by_session)(session_id) + + async def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + return await async_(self._delete_entries_older_than)(days) + async def _get_create_memory_table_sql(self) -> str: owner_id_line = "" if self._owner_id_column_ddl: @@ -979,10 +1259,6 @@ def _create_tables(self) -> None: with self._config.provide_session() as driver: driver.execute_script(run_(self._get_create_memory_table_sql)()) - async def create_tables(self) -> None: - """Create tables if they don't exist.""" - await async_(self._create_tables)() - def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: if not self._enabled: msg = "Memory store is disabled" @@ -1026,10 +1302,6 @@ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "objec inserted_count += cur.rowcount return inserted_count - async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: - """Bulk insert memory entries with deduplication.""" - return await async_(self._insert_memory_entries)(entries, owner_id) - def _search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": @@ -1071,12 +1343,6 @@ def _search_entries( return [cast("MemoryRecord", dict(zip(columns, row, strict=False))) for row in rows] - async def search_entries( - self, query: str, app_name: str, user_id: str, limit: "int | None" = None - ) -> "list[MemoryRecord]": - """Search memory entries by text query.""" - return await async_(self._search_entries)(query, app_name, user_id, limit) - def _delete_entries_by_session(self, session_id: str) -> int: if not self._enabled: msg = "Memory store is disabled" @@ -1088,10 +1354,6 @@ def _delete_entries_by_session(self, session_id: str) -> int: conn.commit() return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 - async def delete_entries_by_session(self, session_id: str) -> int: - """Delete all memory entries for a specific session.""" - return await async_(self._delete_entries_by_session)(session_id) - def _delete_entries_older_than(self, days: int) -> int: if not self._enabled: msg = "Memory store is disabled" @@ -1105,7 +1367,3 @@ def _delete_entries_older_than(self, days: int) -> int: cur.execute(sql.encode()) conn.commit() return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 - - async def delete_entries_older_than(self, days: int) -> int: - """Delete memory entries older than specified days.""" - return await async_(self._delete_entries_older_than)(days) diff --git a/sqlspec/adapters/duckdb/adk/store.py b/sqlspec/adapters/duckdb/adk/store.py index a216f9312..37e164c3d 100644 --- a/sqlspec/adapters/duckdb/adk/store.py +++ b/sqlspec/adapters/duckdb/adk/store.py @@ -86,17 +86,172 @@ def __init__(self, config: "DuckDBConfig") -> None: Notes: Configuration is read from config.extension_config["adk"]: - - session_table: Sessions table name (default: "adk_sessions") - - events_table: Events table name (default: "adk_events") + - session_table: Sessions table name (default: "adk_session") + - events_table: Events table name (default: "adk_event") - owner_id_column: Optional owner FK column DDL (default: None) """ super().__init__(config) + async def create_tables(self) -> None: + """Create both sessions and events tables if they don't exist.""" + await async_(self._create_tables)() + + async def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new session. + + Args: + session_id: Unique session identifier. + app_name: Application name. + user_id: User identifier. + state: Initial session state. + owner_id: Optional owner ID value for owner_id_column (if configured). + + Returns: + Created session record. + + Notes: + Uses current UTC timestamp for create_time and update_time. + State is JSON-serialized using SQLSpec serializers. + """ + return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) + + async def get_session( + self, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + """Get session by ID. + + Args: + session_id: Session identifier. + renew_for: If positive, touch update_time while reading. + + Returns: + Session record or None if not found. + + Notes: + DuckDB returns datetime objects for TIMESTAMPTZ columns. + JSON is parsed from database storage. + """ + return await async_(self._get_session)(session_id, renew_for) + + async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + """Update session state. + + Args: + session_id: Session identifier. + state: New state dictionary (replaces existing state). + + Notes: + This replaces the entire state dictionary. + Update time is automatically set to current UTC timestamp. + """ + await async_(self._update_session_state)(session_id, state) + + async def delete_session(self, session_id: str) -> None: + """Delete session and all associated events. + + Args: + session_id: Session identifier. + + Notes: + DuckDB doesn't support CASCADE in foreign keys, so we manually delete events first. + """ + await async_(self._delete_session)(session_id) + + async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app, optionally filtered by user. + + Args: + app_name: Application name. + user_id: User identifier. If None, lists all sessions for the app. + + Returns: + List of session records ordered by update_time DESC. + + Notes: + Uses composite index on (app_name, user_id) when user_id is provided. + """ + return await async_(self._list_sessions)(app_name, user_id) + + async def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session. + + Args: + event_record: Event record with 5 keys (session_id, invocation_id, + author, timestamp, event_data). + """ + await async_(self._append_event)(event_record) + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> SessionRecord: + """Atomically append an event and update the session's durable state. + + The event insert and state update succeed together or fail together + within a single DuckDB transaction; the updated SessionRecord is + returned via UPDATE...RETURNING. + + Args: + event_record: Event record to store (5-key shape). + session_id: Session identifier whose state should be updated. + state: Post-append durable state snapshot (``temp:`` keys already + stripped by the service layer). + """ + return await async_(self._append_event_and_update_state)(event_record, session_id, state) + + async def get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + """Get events for a session. + + Args: + session_id: Session identifier. + after_timestamp: Only return events after this time. + limit: Maximum number of events to return. + + Returns: + List of event records ordered by timestamp ASC. + """ + return await async_(self._get_events)(session_id, after_timestamp, limit) + + async def delete_expired_events(self, before: "datetime") -> int: + """Delete events older than the given timestamp.""" + return await async_(self._delete_expired_events)(before) + + async def delete_idle_sessions(self, updated_before: "datetime") -> int: + """Delete sessions whose update_time predates the given threshold.""" + return await async_(self._delete_idle_sessions)(updated_before) + + async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Return app-scoped state for an application.""" + return await async_(self._get_app_state)(app_name) + + async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Return user-scoped state for an application user.""" + return await async_(self._get_user_state)(app_name, user_id) + + async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Insert or replace app-scoped state for an application.""" + await async_(self._upsert_app_state)(app_name, state) + + async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Insert or replace user-scoped state for an application user.""" + await async_(self._upsert_user_state)(app_name, user_id, state) + + async def get_metadata(self, key: str) -> "str | None": + """Return a value from the ADK internal metadata table.""" + return await async_(self._get_metadata)(key) + + async def set_metadata(self, key: str, value: str) -> None: + """Set a value in the ADK internal metadata table.""" + await async_(self._set_metadata)(key, value) + async def _get_create_sessions_table_sql(self) -> str: """Get DuckDB CREATE TABLE SQL for sessions. Returns: - SQL statement to create adk_sessions table with indexes. + SQL statement to create adk_session table with indexes. Notes: - VARCHAR for IDs and names @@ -128,7 +283,7 @@ async def _get_create_events_table_sql(self) -> str: """Get DuckDB CREATE TABLE SQL for events. Returns: - SQL statement to create adk_events table with indexes. + SQL statement to create adk_event table with indexes. Notes: - 5-column schema: session_id, invocation_id, author, timestamp, event_data @@ -150,6 +305,34 @@ async def _get_create_events_table_sql(self) -> str: CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session ON {self._events_table}(session_id, timestamp ASC); """ + async def _get_create_app_states_table_sql(self) -> str: + """Get DuckDB CREATE TABLE SQL for app-scoped state.""" + return self.__get_create_app_states_table_sql_sync() + + async def _get_create_user_states_table_sql(self) -> str: + """Get DuckDB CREATE TABLE SQL for user-scoped state.""" + return self.__get_create_user_states_table_sql_sync() + + async def _get_create_metadata_table_sql(self) -> str: + """Get DuckDB CREATE TABLE SQL for ADK internal metadata.""" + return self.__get_create_metadata_table_sql_sync() + + async def _get_seed_metadata_sql(self) -> str: + """Get DuckDB SQL that seeds the ADK schema version metadata row.""" + return self.__get_seed_metadata_sql_sync() + + def _get_drop_app_states_table_sql(self) -> str: + """Get DuckDB DROP TABLE SQL for app-scoped state.""" + return f"DROP TABLE IF EXISTS {self._app_state_table}" + + def _get_drop_user_states_table_sql(self) -> str: + """Get DuckDB DROP TABLE SQL for user-scoped state.""" + return f"DROP TABLE IF EXISTS {self._user_state_table}" + + def _get_drop_metadata_table_sql(self) -> str: + """Get DuckDB DROP TABLE SQL for ADK internal metadata.""" + return f"DROP TABLE IF EXISTS {self._metadata_table}" + def _get_drop_tables_sql(self) -> "list[str]": """Get DuckDB DROP TABLE SQL statements. @@ -160,13 +343,23 @@ def _get_drop_tables_sql(self) -> "list[str]": Order matters: drop events table (child) before sessions (parent). DuckDB automatically drops indexes when dropping tables. """ - return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE IF EXISTS {self._events_table}", + f"DROP TABLE IF EXISTS {self._session_table}", + ] def _create_tables(self) -> None: """Synchronous implementation of create_tables.""" with self._config.provide_connection() as conn: conn.execute(self.__get_create_sessions_table_sql_sync()) conn.execute(self.__get_create_events_table_sql_sync()) + conn.execute(self.__get_create_app_states_table_sql_sync()) + conn.execute(self.__get_create_user_states_table_sql_sync()) + conn.execute(self.__get_create_metadata_table_sql_sync()) + conn.execute(self.__get_seed_metadata_sql_sync()) def __get_create_sessions_table_sql_sync(self) -> str: """Synchronous version of DDL generation for use in _create_tables.""" @@ -201,9 +394,43 @@ def __get_create_events_table_sql_sync(self) -> str: CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session ON {self._events_table}(session_id, timestamp ASC); """ - async def create_tables(self) -> None: - """Create both sessions and events tables if they don't exist.""" - await async_(self._create_tables)() + def __get_create_app_states_table_sql_sync(self) -> str: + """Synchronous DuckDB app-scoped state table DDL.""" + return f""" + CREATE TABLE IF NOT EXISTS {self._app_state_table} ( + app_name VARCHAR PRIMARY KEY, + state JSON NOT NULL, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + """ + + def __get_create_user_states_table_sql_sync(self) -> str: + """Synchronous DuckDB user-scoped state table DDL.""" + return f""" + CREATE TABLE IF NOT EXISTS {self._user_state_table} ( + app_name VARCHAR NOT NULL, + user_id VARCHAR NOT NULL, + state JSON NOT NULL, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (app_name, user_id) + ); + """ + + def __get_create_metadata_table_sql_sync(self) -> str: + """Synchronous DuckDB internal metadata table DDL.""" + return f""" + CREATE TABLE IF NOT EXISTS {self._metadata_table} ( + key VARCHAR PRIMARY KEY, + value VARCHAR NOT NULL + ); + """ + + def __get_seed_metadata_sql_sync(self) -> str: + """Synchronous DuckDB schema-version metadata seed SQL.""" + return f""" + INSERT OR IGNORE INTO {self._metadata_table} (key, value) + VALUES ('schema_version', '1') + """ def _create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None @@ -235,27 +462,6 @@ def _create_session( id=session_id, app_name=app_name, user_id=user_id, state=state, create_time=now, update_time=now ) - async def create_session( - self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None - ) -> SessionRecord: - """Create a new session. - - Args: - session_id: Unique session identifier. - app_name: Application name. - user_id: User identifier. - state: Initial session state. - owner_id: Optional owner ID value for owner_id_column (if configured). - - Returns: - Created session record. - - Notes: - Uses current UTC timestamp for create_time and update_time. - State is JSON-serialized using SQLSpec serializers. - """ - return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) - def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = None) -> "SessionRecord | None": """Synchronous implementation of get_session.""" sql = f""" @@ -294,24 +500,6 @@ def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = No return None raise - async def get_session( - self, session_id: str, *, renew_for: "int | timedelta | None" = None - ) -> "SessionRecord | None": - """Get session by ID. - - Args: - session_id: Session identifier. - renew_for: If positive, touch update_time while reading. - - Returns: - Session record or None if not found. - - Notes: - DuckDB returns datetime objects for TIMESTAMPTZ columns. - JSON is parsed from database storage. - """ - return await async_(self._get_session)(session_id, renew_for) - def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: """Synchronous implementation of update_session_state.""" now = datetime.now(timezone.utc) @@ -327,19 +515,6 @@ def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> Non conn.execute(sql, (state_json, now, session_id)) conn.commit() - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state. - - Args: - session_id: Session identifier. - state: New state dictionary (replaces existing state). - - Notes: - This replaces the entire state dictionary. - Update time is automatically set to current UTC timestamp. - """ - await async_(self._update_session_state)(session_id, state) - def _delete_session(self, session_id: str) -> None: """Synchronous implementation of delete_session.""" delete_events_sql = f"DELETE FROM {self._events_table} WHERE session_id = ?" @@ -350,17 +525,6 @@ def _delete_session(self, session_id: str) -> None: conn.execute(delete_session_sql, (session_id,)) conn.commit() - async def delete_session(self, session_id: str) -> None: - """Delete session and all associated events. - - Args: - session_id: Session identifier. - - Notes: - DuckDB doesn't support CASCADE in foreign keys, so we manually delete events first. - """ - await async_(self._delete_session)(session_id) - def _list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[SessionRecord]": """Synchronous implementation of list_sessions.""" if user_id is None: @@ -401,21 +565,6 @@ def _list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[S return [] raise - async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app, optionally filtered by user. - - Args: - app_name: Application name. - user_id: User identifier. If None, lists all sessions for the app. - - Returns: - List of session records ordered by update_time DESC. - - Notes: - Uses composite index on (app_name, user_id) when user_id is provided. - """ - return await async_(self._list_sessions)(app_name, user_id) - def _append_event(self, event_record: EventRecord) -> None: """Synchronous implementation of append_event.""" event_data_str = to_json(event_record["event_data"]) @@ -439,15 +588,6 @@ def _append_event(self, event_record: EventRecord) -> None: ) conn.commit() - async def append_event(self, event_record: EventRecord) -> None: - """Append an event to a session. - - Args: - event_record: Event record with 5 keys (session_id, invocation_id, - author, timestamp, event_data). - """ - await async_(self._append_event)(event_record) - def _append_event_and_update_state( self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" ) -> SessionRecord: @@ -499,23 +639,6 @@ def _append_event_and_update_state( update_time=update_time, ) - async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" - ) -> SessionRecord: - """Atomically append an event and update the session's durable state. - - The event insert and state update succeed together or fail together - within a single DuckDB transaction; the updated SessionRecord is - returned via UPDATE...RETURNING. - - Args: - event_record: Event record to store (5-key shape). - session_id: Session identifier whose state should be updated. - state: Post-append durable state snapshot (``temp:`` keys already - stripped by the service layer). - """ - return await async_(self._append_event_and_update_state)(event_record, session_id, state) - def _get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None ) -> "list[EventRecord]": @@ -557,21 +680,6 @@ def _get_events( return [] raise - async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None - ) -> "list[EventRecord]": - """Get events for a session. - - Args: - session_id: Session identifier. - after_timestamp: Only return events after this time. - limit: Maximum number of events to return. - - Returns: - List of event records ordered by timestamp ASC. - """ - return await async_(self._get_events)(session_id, after_timestamp, limit) - def _delete_expired_events(self, before: "datetime") -> int: count_sql = f"SELECT COUNT(*) FROM {self._events_table} WHERE timestamp < ?" delete_sql = f"DELETE FROM {self._events_table} WHERE timestamp < ?" @@ -588,10 +696,6 @@ def _delete_expired_events(self, before: "datetime") -> int: return 0 raise - async def delete_expired_events(self, before: "datetime") -> int: - """Delete events older than the given timestamp.""" - return await async_(self._delete_expired_events)(before) - def _delete_idle_sessions(self, updated_before: "datetime") -> int: count_sql = f"SELECT COUNT(*) FROM {self._session_table} WHERE update_time < ?" delete_events_sql = f""" @@ -613,9 +717,84 @@ def _delete_idle_sessions(self, updated_before: "datetime") -> int: return 0 raise - async def delete_idle_sessions(self, updated_before: "datetime") -> int: - """Delete sessions whose update_time predates the given threshold.""" - return await async_(self._delete_idle_sessions)(updated_before) + def _get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Synchronous implementation of get_app_state.""" + sql = f"SELECT state FROM {self._app_state_table} WHERE app_name = ?" + + try: + with self._config.provide_connection() as conn: + row = conn.execute(sql, (app_name,)).fetchone() + return from_json(row[0]) if row is not None and isinstance(row[0], str) else (row[0] if row else None) + except Exception as e: + if DUCKDB_TABLE_NOT_FOUND_ERROR in str(e): + return None + raise + + def _get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Synchronous implementation of get_user_state.""" + sql = f"SELECT state FROM {self._user_state_table} WHERE app_name = ? AND user_id = ?" + + try: + with self._config.provide_connection() as conn: + row = conn.execute(sql, (app_name, user_id)).fetchone() + return from_json(row[0]) if row is not None and isinstance(row[0], str) else (row[0] if row else None) + except Exception as e: + if DUCKDB_TABLE_NOT_FOUND_ERROR in str(e): + return None + raise + + def _upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Synchronous implementation of upsert_app_state.""" + sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES (?, ?, ?) + ON CONFLICT(app_name) DO UPDATE SET + state = excluded.state, + update_time = excluded.update_time + """ + + with self._config.provide_connection() as conn: + conn.execute(sql, (app_name, to_json(state), datetime.now(timezone.utc))) + conn.commit() + + def _upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Synchronous implementation of upsert_user_state.""" + sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (?, ?, ?, ?) + ON CONFLICT(app_name, user_id) DO UPDATE SET + state = excluded.state, + update_time = excluded.update_time + """ + + with self._config.provide_connection() as conn: + conn.execute(sql, (app_name, user_id, to_json(state), datetime.now(timezone.utc))) + conn.commit() + + def _get_metadata(self, key: str) -> "str | None": + """Synchronous implementation of get_metadata.""" + sql = f"SELECT value FROM {self._metadata_table} WHERE key = ?" + + try: + with self._config.provide_connection() as conn: + row = conn.execute(sql, (key,)).fetchone() + return row[0] if row is not None else None + except Exception as e: + if DUCKDB_TABLE_NOT_FOUND_ERROR in str(e): + return None + raise + + def _set_metadata(self, key: str, value: str) -> None: + """Synchronous implementation of set_metadata.""" + sql = f""" + INSERT INTO {self._metadata_table} (key, value) + VALUES (?, ?) + ON CONFLICT(key) DO UPDATE SET value = excluded.value + """ + + with self._config.provide_connection() as conn: + conn.execute(sql, (key, value)) + conn.commit() class DuckdbADKMemoryStore(BaseAsyncADKMemoryStore["DuckDBConfig"]): @@ -680,6 +859,35 @@ def __init__(self, config: "DuckDBConfig") -> None: """ super().__init__(config) + async def create_tables(self) -> None: + """Create the memory table and indexes if they don't exist.""" + await async_(self._create_tables)() + + async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication. + + After successful inserts, refreshes the FTS index if FTS is enabled. + """ + return await async_(self._insert_memory_entries)(entries, owner_id) + + async def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query. + + When FTS is enabled, uses ``match_bm25()`` for BM25-ranked results. + Falls back to ILIKE for simple substring matching. + """ + return await async_(self._search_entries)(query, app_name, user_id, limit) + + async def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return await async_(self._delete_entries_by_session)(session_id) + + async def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + return await async_(self._delete_entries_older_than)(days) + def _ensure_fts_extension(self, conn: Any) -> bool: """Ensure the DuckDB FTS extension is available for this connection.""" with contextlib.suppress(Exception): @@ -800,10 +1008,6 @@ def __get_create_memory_table_sql_sync(self) -> str: ON {self._memory_table}(session_id); """ - async def create_tables(self) -> None: - """Create the memory table and indexes if they don't exist.""" - await async_(self._create_tables)() - def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: """Synchronous implementation of insert_memory_entries.""" if not self._enabled: @@ -874,13 +1078,6 @@ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "objec return inserted_count - async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: - """Bulk insert memory entries with deduplication. - - After successful inserts, refreshes the FTS index if FTS is enabled. - """ - return await async_(self._insert_memory_entries)(entries, owner_id) - def _search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": @@ -936,16 +1133,6 @@ def _search_entries( records.append(record) return records - async def search_entries( - self, query: str, app_name: str, user_id: str, limit: "int | None" = None - ) -> "list[MemoryRecord]": - """Search memory entries by text query. - - When FTS is enabled, uses ``match_bm25()`` for BM25-ranked results. - Falls back to ILIKE for simple substring matching. - """ - return await async_(self._search_entries)(query, app_name, user_id, limit) - def _delete_entries_by_session(self, session_id: str) -> int: """Synchronous implementation of delete_entries_by_session.""" if not self._enabled: @@ -961,10 +1148,6 @@ def _delete_entries_by_session(self, session_id: str) -> int: self._refresh_fts_index(conn) return deleted_count - async def delete_entries_by_session(self, session_id: str) -> int: - """Delete all memory entries for a specific session.""" - return await async_(self._delete_entries_by_session)(session_id) - def _delete_entries_older_than(self, days: int) -> int: """Synchronous implementation of delete_entries_older_than.""" if not self._enabled: @@ -983,7 +1166,3 @@ def _delete_entries_older_than(self, days: int) -> int: if self._use_fts and deleted_count > 0: self._refresh_fts_index(conn) return deleted_count - - async def delete_entries_older_than(self, days: int) -> int: - """Delete memory entries older than specified days.""" - return await async_(self._delete_entries_older_than)(days) diff --git a/sqlspec/adapters/mysqlconnector/adk/store.py b/sqlspec/adapters/mysqlconnector/adk/store.py index b92a6da93..374a4ef17 100644 --- a/sqlspec/adapters/mysqlconnector/adk/store.py +++ b/sqlspec/adapters/mysqlconnector/adk/store.py @@ -80,6 +80,62 @@ def _mysql_events_ddl(events_table: str, session_table: str) -> str: """ +def _mysql_app_states_ddl(app_state_table: str) -> str: + """Generate shared MySQL app-scoped state CREATE TABLE DDL.""" + return f""" + CREATE TABLE IF NOT EXISTS {app_state_table} ( + app_name VARCHAR(128) PRIMARY KEY, + state JSON NOT NULL, + update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + +def _mysql_user_states_ddl(user_state_table: str) -> str: + """Generate shared MySQL user-scoped state CREATE TABLE DDL.""" + return f""" + CREATE TABLE IF NOT EXISTS {user_state_table} ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + state JSON NOT NULL, + update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6), + PRIMARY KEY (app_name, user_id) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + +def _mysql_metadata_ddl(metadata_table: str) -> str: + """Generate shared MySQL ADK internal metadata CREATE TABLE DDL.""" + return f""" + CREATE TABLE IF NOT EXISTS {metadata_table} ( + `key` VARCHAR(128) PRIMARY KEY, + value VARCHAR(512) NOT NULL + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + +def _mysql_seed_metadata_sql(metadata_table: str) -> str: + """Generate shared MySQL ADK schema-version seed SQL.""" + return f""" + INSERT IGNORE INTO {metadata_table} (`key`, value) + VALUES ('schema_version', '1') + """ + + +def _state_from_row_value(value: Any) -> "dict[str, Any] | None": + if value is None: + return None + if isinstance(value, str): + return cast("dict[str, Any]", from_json(value)) + return cast("dict[str, Any]", value) + + +def _metadata_from_row_value(value: Any) -> "str | None": + if value is None: + return None + return str(value) + + class MysqlConnectorAsyncADKStore(BaseAsyncADKStore["MysqlConnectorAsyncConfig"]): """MySQL/MariaDB ADK store using mysql-connector async driver. @@ -100,22 +156,14 @@ class MysqlConnectorAsyncADKStore(BaseAsyncADKStore["MysqlConnectorAsyncConfig"] def __init__(self, config: "MysqlConnectorAsyncConfig") -> None: super().__init__(config) - def _parse_owner_id_column_for_mysql(self, column_ddl: str) -> "tuple[str, str]": - return _parse_owner_id_column_for_mysql(column_ddl) - - async def _get_create_sessions_table_sql(self) -> str: - return _mysql_sessions_ddl(self._session_table, self._owner_id_column_ddl) - - async def _get_create_events_table_sql(self) -> str: - return _mysql_events_ddl(self._events_table, self._session_table) - - def _get_drop_tables_sql(self) -> "list[str]": - return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] - async def create_tables(self) -> None: async with self._config.provide_session() as driver: await driver.execute_script(await self._get_create_sessions_table_sql()) await driver.execute_script(await self._get_create_events_table_sql()) + await driver.execute_script(await self._get_create_app_states_table_sql()) + await driver.execute_script(await self._get_create_user_states_table_sql()) + await driver.execute_script(await self._get_create_metadata_table_sql()) + await driver.execute_script(await self._get_seed_metadata_sql()) async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None @@ -449,6 +497,141 @@ async def delete_idle_sessions(self, updated_before: "datetime") -> int: return 0 raise + async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + sql = f"SELECT state FROM {self._app_state_table} WHERE app_name = %s" + + try: + async with self._config.provide_connection() as conn: + cursor = await conn.cursor() + try: + await cursor.execute(sql, (app_name,)) + row = await cursor.fetchone() + finally: + await cursor.close() + return _state_from_row_value(row[0]) if row is not None else None + except mysql.connector.Error as exc: + if "doesn't exist" in str(exc) or getattr(exc, "errno", None) == MYSQL_TABLE_NOT_FOUND_ERROR: + return None + raise + + async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + sql = f"SELECT state FROM {self._user_state_table} WHERE app_name = %s AND user_id = %s" + + try: + async with self._config.provide_connection() as conn: + cursor = await conn.cursor() + try: + await cursor.execute(sql, (app_name, user_id)) + row = await cursor.fetchone() + finally: + await cursor.close() + return _state_from_row_value(row[0]) if row is not None else None + except mysql.connector.Error as exc: + if "doesn't exist" in str(exc) or getattr(exc, "errno", None) == MYSQL_TABLE_NOT_FOUND_ERROR: + return None + raise + + async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES (%s, %s, UTC_TIMESTAMP(6)) + ON DUPLICATE KEY UPDATE state = VALUES(state), update_time = UTC_TIMESTAMP(6) + """ + + async with self._config.provide_connection() as conn: + cursor = await conn.cursor() + try: + await cursor.execute(sql, (app_name, to_json(state))) + finally: + await cursor.close() + await conn.commit() + + async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (%s, %s, %s, UTC_TIMESTAMP(6)) + ON DUPLICATE KEY UPDATE state = VALUES(state), update_time = UTC_TIMESTAMP(6) + """ + + async with self._config.provide_connection() as conn: + cursor = await conn.cursor() + try: + await cursor.execute(sql, (app_name, user_id, to_json(state))) + finally: + await cursor.close() + await conn.commit() + + async def get_metadata(self, key: str) -> "str | None": + sql = f"SELECT value FROM {self._metadata_table} WHERE `key` = %s" + + try: + async with self._config.provide_connection() as conn: + cursor = await conn.cursor() + try: + await cursor.execute(sql, (key,)) + row = await cursor.fetchone() + finally: + await cursor.close() + return _metadata_from_row_value(row[0]) if row is not None else None + except mysql.connector.Error as exc: + if "doesn't exist" in str(exc) or getattr(exc, "errno", None) == MYSQL_TABLE_NOT_FOUND_ERROR: + return None + raise + + async def set_metadata(self, key: str, value: str) -> None: + sql = f""" + INSERT INTO {self._metadata_table} (`key`, value) + VALUES (%s, %s) + ON DUPLICATE KEY UPDATE value = VALUES(value) + """ + + async with self._config.provide_connection() as conn: + cursor = await conn.cursor() + try: + await cursor.execute(sql, (key, value)) + finally: + await cursor.close() + await conn.commit() + + def _parse_owner_id_column_for_mysql(self, column_ddl: str) -> "tuple[str, str]": + return _parse_owner_id_column_for_mysql(column_ddl) + + async def _get_create_sessions_table_sql(self) -> str: + return _mysql_sessions_ddl(self._session_table, self._owner_id_column_ddl) + + async def _get_create_events_table_sql(self) -> str: + return _mysql_events_ddl(self._events_table, self._session_table) + + async def _get_create_app_states_table_sql(self) -> str: + return _mysql_app_states_ddl(self._app_state_table) + + async def _get_create_user_states_table_sql(self) -> str: + return _mysql_user_states_ddl(self._user_state_table) + + async def _get_create_metadata_table_sql(self) -> str: + return _mysql_metadata_ddl(self._metadata_table) + + async def _get_seed_metadata_sql(self) -> str: + return _mysql_seed_metadata_sql(self._metadata_table) + + def _get_drop_app_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._app_state_table}" + + def _get_drop_user_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._user_state_table}" + + def _get_drop_metadata_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._metadata_table}" + + def _get_drop_tables_sql(self) -> "list[str]": + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE IF EXISTS {self._events_table}", + f"DROP TABLE IF EXISTS {self._session_table}", + ] + class MysqlConnectorSyncADKStore(BaseAsyncADKStore["MysqlConnectorSyncConfig"]): """MySQL/MariaDB ADK store using mysql-connector sync driver. @@ -470,6 +653,82 @@ class MysqlConnectorSyncADKStore(BaseAsyncADKStore["MysqlConnectorSyncConfig"]): def __init__(self, config: "MysqlConnectorSyncConfig") -> None: super().__init__(config) + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + + async def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new session.""" + return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) + + async def get_session( + self, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + """Get session by ID.""" + return await async_(self._get_session)(session_id, renew_for) + + async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + """Update session state.""" + await async_(self._update_session_state)(session_id, state) + + async def delete_session(self, session_id: str) -> None: + """Delete session and associated events.""" + await async_(self._delete_session)(session_id) + + async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app.""" + return await async_(self._list_sessions)(app_name, user_id) + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> SessionRecord: + """Atomically append an event and update the session's durable state.""" + return await async_(self._append_event_and_update_state)(event_record, session_id, state) + + async def get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + """Get events for a session.""" + return await async_(self._get_events)(session_id, after_timestamp, limit) + + async def delete_expired_events(self, before: "datetime") -> int: + """Delete events older than the given timestamp.""" + return await async_(self._delete_expired_events)(before) + + async def delete_idle_sessions(self, updated_before: "datetime") -> int: + """Delete sessions whose update_time predates the given threshold.""" + return await async_(self._delete_idle_sessions)(updated_before) + + async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Return app-scoped state for an application.""" + return await async_(self._get_app_state)(app_name) + + async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Return user-scoped state for an application user.""" + return await async_(self._get_user_state)(app_name, user_id) + + async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Insert or replace app-scoped state for an application.""" + await async_(self._upsert_app_state)(app_name, state) + + async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Insert or replace user-scoped state for an application user.""" + await async_(self._upsert_user_state)(app_name, user_id, state) + + async def get_metadata(self, key: str) -> "str | None": + """Return a value from the ADK internal metadata table.""" + return await async_(self._get_metadata)(key) + + async def set_metadata(self, key: str, value: str) -> None: + """Set a value in the ADK internal metadata table.""" + await async_(self._set_metadata)(key, value) + + async def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session.""" + await async_(self._append_event)(event_record) + def _parse_owner_id_column_for_mysql(self, column_ddl: str) -> "tuple[str, str]": return _parse_owner_id_column_for_mysql(column_ddl) @@ -479,17 +738,44 @@ async def _get_create_sessions_table_sql(self) -> str: async def _get_create_events_table_sql(self) -> str: return _mysql_events_ddl(self._events_table, self._session_table) + async def _get_create_app_states_table_sql(self) -> str: + return _mysql_app_states_ddl(self._app_state_table) + + async def _get_create_user_states_table_sql(self) -> str: + return _mysql_user_states_ddl(self._user_state_table) + + async def _get_create_metadata_table_sql(self) -> str: + return _mysql_metadata_ddl(self._metadata_table) + + async def _get_seed_metadata_sql(self) -> str: + return _mysql_seed_metadata_sql(self._metadata_table) + + def _get_drop_app_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._app_state_table}" + + def _get_drop_user_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._user_state_table}" + + def _get_drop_metadata_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._metadata_table}" + def _get_drop_tables_sql(self) -> "list[str]": - return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE IF EXISTS {self._events_table}", + f"DROP TABLE IF EXISTS {self._session_table}", + ] def _create_tables(self) -> None: with self._config.provide_session() as driver: driver.execute_script(run_(self._get_create_sessions_table_sql)()) driver.execute_script(run_(self._get_create_events_table_sql)()) - - async def create_tables(self) -> None: - """Create tables if they don't exist.""" - await async_(self._create_tables)() + driver.execute_script(run_(self._get_create_app_states_table_sql)()) + driver.execute_script(run_(self._get_create_user_states_table_sql)()) + driver.execute_script(run_(self._get_create_metadata_table_sql)()) + driver.execute_script(run_(self._get_seed_metadata_sql)()) def _create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None @@ -524,12 +810,6 @@ def _create_session( raise RuntimeError(msg) return result - async def create_session( - self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None - ) -> SessionRecord: - """Create a new session.""" - return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) - def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = None) -> "SessionRecord | None": sql = f""" SELECT id, app_name, user_id, state, create_time, update_time @@ -569,12 +849,6 @@ def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = No return None raise - async def get_session( - self, session_id: str, *, renew_for: "int | timedelta | None" = None - ) -> "SessionRecord | None": - """Get session by ID.""" - return await async_(self._get_session)(session_id, renew_for) - def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: state_json = to_json(state) @@ -592,10 +866,6 @@ def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> Non cursor.close() conn.commit() - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state.""" - await async_(self._update_session_state)(session_id, state) - def _delete_session(self, session_id: str) -> None: sql = f"DELETE FROM {self._session_table} WHERE id = %s" @@ -607,10 +877,6 @@ def _delete_session(self, session_id: str) -> None: cursor.close() conn.commit() - async def delete_session(self, session_id: str) -> None: - """Delete session and associated events.""" - await async_(self._delete_session)(session_id) - def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": if user_id is None: sql = f""" @@ -654,10 +920,6 @@ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Ses return [] raise - async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app.""" - return await async_(self._list_sessions)(app_name, user_id) - def _append_event_and_update_state( self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" ) -> SessionRecord: @@ -728,12 +990,6 @@ def _append_event_and_update_state( update_time=cast("datetime", row[5]), ) - async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" - ) -> SessionRecord: - """Atomically append an event and update the session's durable state.""" - return await async_(self._append_event_and_update_state)(event_record, session_id, state) - def _insert_event(self, event_record: EventRecord) -> None: event_data = event_record["event_data"] event_data_str = to_json(event_data) if not isinstance(event_data, str) else event_data @@ -816,12 +1072,6 @@ def _get_events( return [] raise - async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None - ) -> "list[EventRecord]": - """Get events for a session.""" - return await async_(self._get_events)(session_id, after_timestamp, limit) - def _delete_expired_events(self, before: "datetime") -> int: sql = f"DELETE FROM {self._events_table} WHERE timestamp < %s" @@ -839,10 +1089,6 @@ def _delete_expired_events(self, before: "datetime") -> int: return 0 raise - async def delete_expired_events(self, before: "datetime") -> int: - """Delete events older than the given timestamp.""" - return await async_(self._delete_expired_events)(before) - def _delete_idle_sessions(self, updated_before: "datetime") -> int: sql = f"DELETE FROM {self._session_table} WHERE update_time < %s" @@ -860,18 +1106,106 @@ def _delete_idle_sessions(self, updated_before: "datetime") -> int: return 0 raise - async def delete_idle_sessions(self, updated_before: "datetime") -> int: - """Delete sessions whose update_time predates the given threshold.""" - return await async_(self._delete_idle_sessions)(updated_before) + def _get_app_state(self, app_name: str) -> "dict[str, Any] | None": + sql = f"SELECT state FROM {self._app_state_table} WHERE app_name = %s" + + try: + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, (app_name,)) + row = cursor.fetchone() + finally: + cursor.close() + return _state_from_row_value(row[0]) if row is not None else None + except mysql.connector.Error as exc: + if "doesn't exist" in str(exc) or getattr(exc, "errno", None) == MYSQL_TABLE_NOT_FOUND_ERROR: + return None + raise + + def _get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + sql = f"SELECT state FROM {self._user_state_table} WHERE app_name = %s AND user_id = %s" + + try: + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, (app_name, user_id)) + row = cursor.fetchone() + finally: + cursor.close() + return _state_from_row_value(row[0]) if row is not None else None + except mysql.connector.Error as exc: + if "doesn't exist" in str(exc) or getattr(exc, "errno", None) == MYSQL_TABLE_NOT_FOUND_ERROR: + return None + raise + + def _upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES (%s, %s, UTC_TIMESTAMP(6)) + ON DUPLICATE KEY UPDATE state = VALUES(state), update_time = UTC_TIMESTAMP(6) + """ + + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, (app_name, to_json(state))) + finally: + cursor.close() + conn.commit() + + def _upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (%s, %s, %s, UTC_TIMESTAMP(6)) + ON DUPLICATE KEY UPDATE state = VALUES(state), update_time = UTC_TIMESTAMP(6) + """ + + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, (app_name, user_id, to_json(state))) + finally: + cursor.close() + conn.commit() + + def _get_metadata(self, key: str) -> "str | None": + sql = f"SELECT value FROM {self._metadata_table} WHERE `key` = %s" + + try: + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, (key,)) + row = cursor.fetchone() + finally: + cursor.close() + return _metadata_from_row_value(row[0]) if row is not None else None + except mysql.connector.Error as exc: + if "doesn't exist" in str(exc) or getattr(exc, "errno", None) == MYSQL_TABLE_NOT_FOUND_ERROR: + return None + raise + + def _set_metadata(self, key: str, value: str) -> None: + sql = f""" + INSERT INTO {self._metadata_table} (`key`, value) + VALUES (%s, %s) + ON DUPLICATE KEY UPDATE value = VALUES(value) + """ + + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, (key, value)) + finally: + cursor.close() + conn.commit() def _append_event(self, event_record: EventRecord) -> None: """Synchronous implementation of append_event.""" self._insert_event(event_record) - async def append_event(self, event_record: EventRecord) -> None: - """Append an event to a session.""" - await async_(self._append_event)(event_record) - class MysqlConnectorAsyncADKMemoryStore(BaseAsyncADKMemoryStore["MysqlConnectorAsyncConfig"]): """MySQL/MariaDB ADK memory store using mysql-connector async driver.""" @@ -881,40 +1215,6 @@ class MysqlConnectorAsyncADKMemoryStore(BaseAsyncADKMemoryStore["MysqlConnectorA def __init__(self, config: "MysqlConnectorAsyncConfig") -> None: super().__init__(config) - async def _get_create_memory_table_sql(self) -> str: - owner_id_line = "" - fk_constraint = "" - if self._owner_id_column_ddl: - col_def, fk_def = _parse_owner_id_column_for_mysql(self._owner_id_column_ddl) - owner_id_line = f",\n {col_def}" - if fk_def: - fk_constraint = f",\n {fk_def}" - - fts_index = "" - if self._use_fts: - fts_index = f",\n FULLTEXT INDEX idx_{self._memory_table}_fts (content_text)" - - return f""" - CREATE TABLE IF NOT EXISTS {self._memory_table} ( - id VARCHAR(128) PRIMARY KEY, - session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - event_id VARCHAR(128) NOT NULL UNIQUE, - author VARCHAR(256){owner_id_line}, - timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - content_json JSON NOT NULL, - content_text TEXT NOT NULL, - metadata_json JSON, - inserted_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - INDEX idx_{self._memory_table}_app_user_time (app_name, user_id, timestamp), - INDEX idx_{self._memory_table}_session (session_id){fts_index}{fk_constraint} - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - """ - - def _get_drop_memory_table_sql(self) -> "list[str]": - return [f"DROP TABLE IF EXISTS {self._memory_table}"] - async def create_tables(self) -> None: if not self._enabled: return @@ -1061,6 +1361,40 @@ async def delete_entries_older_than(self, days: int) -> int: finally: await cursor.close() + async def _get_create_memory_table_sql(self) -> str: + owner_id_line = "" + fk_constraint = "" + if self._owner_id_column_ddl: + col_def, fk_def = _parse_owner_id_column_for_mysql(self._owner_id_column_ddl) + owner_id_line = f",\n {col_def}" + if fk_def: + fk_constraint = f",\n {fk_def}" + + fts_index = "" + if self._use_fts: + fts_index = f",\n FULLTEXT INDEX idx_{self._memory_table}_fts (content_text)" + + return f""" + CREATE TABLE IF NOT EXISTS {self._memory_table} ( + id VARCHAR(128) PRIMARY KEY, + session_id VARCHAR(128) NOT NULL, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + event_id VARCHAR(128) NOT NULL UNIQUE, + author VARCHAR(256){owner_id_line}, + timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + content_json JSON NOT NULL, + content_text TEXT NOT NULL, + metadata_json JSON, + inserted_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + INDEX idx_{self._memory_table}_app_user_time (app_name, user_id, timestamp), + INDEX idx_{self._memory_table}_session (session_id){fts_index}{fk_constraint} + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + def _get_drop_memory_table_sql(self) -> "list[str]": + return [f"DROP TABLE IF EXISTS {self._memory_table}"] + class MysqlConnectorSyncADKMemoryStore(BaseAsyncADKMemoryStore["MysqlConnectorSyncConfig"]): """MySQL/MariaDB ADK memory store using mysql-connector sync driver.""" @@ -1070,6 +1404,28 @@ class MysqlConnectorSyncADKMemoryStore(BaseAsyncADKMemoryStore["MysqlConnectorSy def __init__(self, config: "MysqlConnectorSyncConfig") -> None: super().__init__(config) + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + + async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication.""" + return await async_(self._insert_memory_entries)(entries, owner_id) + + async def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query.""" + return await async_(self._search_entries)(query, app_name, user_id, limit) + + async def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return await async_(self._delete_entries_by_session)(session_id) + + async def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + return await async_(self._delete_entries_older_than)(days) + async def _get_create_memory_table_sql(self) -> str: owner_id_line = "" fk_constraint = "" @@ -1111,10 +1467,6 @@ def _create_tables(self) -> None: with self._config.provide_session() as driver: driver.execute_script(run_(self._get_create_memory_table_sql)()) - async def create_tables(self) -> None: - """Create tables if they don't exist.""" - await async_(self._create_tables)() - def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: if not self._enabled: msg = "Memory store is disabled" @@ -1181,10 +1533,6 @@ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "objec conn.commit() return inserted_count - async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: - """Bulk insert memory entries with deduplication.""" - return await async_(self._insert_memory_entries)(entries, owner_id) - def _search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": @@ -1225,12 +1573,6 @@ def _search_entries( return [cast("MemoryRecord", dict(zip(columns, row, strict=False))) for row in rows] - async def search_entries( - self, query: str, app_name: str, user_id: str, limit: "int | None" = None - ) -> "list[MemoryRecord]": - """Search memory entries by text query.""" - return await async_(self._search_entries)(query, app_name, user_id, limit) - def _delete_entries_by_session(self, session_id: str) -> int: if not self._enabled: msg = "Memory store is disabled" @@ -1246,10 +1588,6 @@ def _delete_entries_by_session(self, session_id: str) -> int: finally: cursor.close() - async def delete_entries_by_session(self, session_id: str) -> int: - """Delete all memory entries for a specific session.""" - return await async_(self._delete_entries_by_session)(session_id) - def _delete_entries_older_than(self, days: int) -> int: if not self._enabled: msg = "Memory store is disabled" @@ -1267,7 +1605,3 @@ def _delete_entries_older_than(self, days: int) -> int: return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 finally: cursor.close() - - async def delete_entries_older_than(self, days: int) -> int: - """Delete memory entries older than specified days.""" - return await async_(self._delete_entries_older_than)(days) diff --git a/sqlspec/adapters/oracledb/adk/store.py b/sqlspec/adapters/oracledb/adk/store.py index 24ccf5c91..1ae0a15e3 100644 --- a/sqlspec/adapters/oracledb/adk/store.py +++ b/sqlspec/adapters/oracledb/adk/store.py @@ -112,8 +112,8 @@ def __init__(self, config: "OracleAsyncConfig") -> None: Notes: Configuration is read from config.extension_config["adk"]: - - session_table: Sessions table name (default: "adk_sessions") - - events_table: Events table name (default: "adk_events") + - session_table: Sessions table name (default: "adk_session") + - events_table: Events table name (default: "adk_event") - owner_id_column: Optional owner FK column DDL (default: None) - in_memory: Enable INMEMORY PRIORITY HIGH clause (default: False) """ @@ -124,721 +124,951 @@ def __init__(self, config: "OracleAsyncConfig") -> None: adk_config = config.extension_config.get("adk", {}) self._in_memory: bool = bool(adk_config.get("in_memory", False)) - async def _get_create_sessions_table_sql(self) -> str: - """Get Oracle CREATE TABLE SQL for sessions table. + async def create_tables(self) -> None: + """Create both sessions and events tables if they don't exist. - Auto-detects optimal JSON storage type based on Oracle version. - Result is cached to minimize database queries. + Notes: + Detects Oracle version to determine optimal JSON storage type. + Uses version-appropriate table schema. """ storage_type = await self._detect_json_storage_type() - return self._get_create_sessions_table_sql_for_type(storage_type) + logger.debug("Creating ADK tables with storage type: %s", storage_type) - async def _get_create_events_table_sql(self) -> str: - """Get Oracle CREATE TABLE SQL for events table. + async with self._config.provide_session() as driver: + await driver.execute_script(self._get_create_sessions_table_sql_for_type(storage_type)) - Auto-detects optimal JSON storage type based on Oracle version. - Result is cached to minimize database queries. - """ - storage_type = await self._detect_json_storage_type() - return self._get_create_events_table_sql_for_type(storage_type) + await driver.execute_script(self._get_create_events_table_sql_for_type(storage_type)) + await driver.execute_script(self._get_create_app_states_table_sql_for_type(storage_type)) + await driver.execute_script(self._get_create_user_states_table_sql_for_type(storage_type)) + await driver.execute_script(await self._get_create_metadata_table_sql()) + await driver.execute_script(await self._get_seed_metadata_sql()) - async def _detect_json_storage_type(self) -> JSONStorageType: - """Detect the appropriate JSON storage type based on Oracle version. + async def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new session. + + Args: + session_id: Unique session identifier. + app_name: Application name. + user_id: User identifier. + state: Initial session state. + owner_id: Optional owner ID value for owner_id_column (if configured). Returns: - Appropriate JSONStorageType for this Oracle version. + Created session record. Notes: - Queries product_component_version to determine Oracle version. - - Oracle 21c+ with compatible >= 20: Native JSON type - - Oracle 12c+: BLOB with IS JSON constraint - - Oracle 11g and earlier: plain BLOB - - Result is cached in self._json_storage_type. + Uses SYSTIMESTAMP for create_time and update_time. + State is serialized using version-appropriate format. + owner_id is ignored if owner_id_column not configured. """ - if self._json_storage_type is not None: - return self._json_storage_type - - version_info = await self._get_version_info() - self._json_storage_type = _storage_type_from_version(version_info) - return self._json_storage_type - - async def _get_version_info(self) -> "OracleVersionInfo | None": - """Return cached Oracle version info using Oracle data dictionary.""" - - if self._oracle_version_info is not None: - return self._oracle_version_info + state_data = await self._serialize_state(state) - async with self._config.provide_session() as driver: - dictionary = OracledbAsyncDataDictionary() - self._oracle_version_info = await dictionary.get_version(driver) + if self._owner_id_column_name: + sql = f""" + INSERT INTO {self._session_table} (id, app_name, user_id, state, create_time, update_time, {self._owner_id_column_name}) + VALUES (:id, :app_name, :user_id, :state, SYSTIMESTAMP, SYSTIMESTAMP, :owner_id) + """ + params = { + "id": session_id, + "app_name": app_name, + "user_id": user_id, + "state": state_data, + "owner_id": owner_id, + } + else: + sql = f""" + INSERT INTO {self._session_table} (id, app_name, user_id, state, create_time, update_time) + VALUES (:id, :app_name, :user_id, :state, SYSTIMESTAMP, SYSTIMESTAMP) + """ + params = {"id": session_id, "app_name": app_name, "user_id": user_id, "state": state_data} - if self._oracle_version_info is None: - logger.warning("Could not detect Oracle version, defaulting to BLOB_JSON storage") + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute(sql, params) + await conn.commit() - return self._oracle_version_info + return await self.get_session(session_id) # type: ignore[return-value] - async def _serialize_state(self, state: "dict[str, Any]") -> "str | bytes": - """Serialize state dictionary to appropriate format based on storage type. + async def get_session( + self, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + """Get session by ID. Args: - state: State dictionary to serialize. + session_id: Session identifier. + renew_for: If positive, touch update_time while reading. Returns: - JSON string for JSON_NATIVE, bytes for BLOB types. - """ - storage_type = await self._detect_json_storage_type() + Session record or None if not found. - if storage_type == JSONStorageType.JSON_NATIVE: - return to_json(state) + Notes: + Oracle returns datetime objects for TIMESTAMP columns. + State is deserialized using version-appropriate format. + """ - return to_json(state, as_bytes=True) + try: + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + await cursor.execute( + f"UPDATE {self._session_table} SET update_time = SYSTIMESTAMP WHERE id = :id", + {"id": session_id}, + ) + await conn.commit() - async def _deserialize_state(self, data: Any) -> "dict[str, Any]": - """Deserialize state data from database format. + await cursor.execute( + f""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {self._session_table} + WHERE id = :id + """, + {"id": session_id}, + ) + row = await cursor.fetchone() - Args: - data: Data from database (may be LOB, str, bytes, or dict). + if row is None: + return None - Returns: - Deserialized state dictionary. + session_id_val, app_name, user_id, state_data, create_time, update_time = row - Notes: - Handles LOB reading if data has read() method. - Oracle JSON type may return dict directly. - """ - if is_async_readable(data): - data = await data.read() - elif is_readable(data): - data = data.read() + state = await self._deserialize_state(state_data) - if isinstance(data, dict): - return cast("dict[str, Any]", _coerce_decimal_values(data)) + return SessionRecord( + id=session_id_val, + app_name=app_name, + user_id=user_id, + state=state, + create_time=create_time, + update_time=update_time, + ) + except oracledb.DatabaseError as e: + error_obj = e.args[0] if e.args else None + if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: + return None + raise - if isinstance(data, bytes): - return from_json(data) # type: ignore[no-any-return] + async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + """Update session state. - if isinstance(data, str): - return from_json(data) # type: ignore[no-any-return] + Args: + session_id: Session identifier. + state: New state dictionary (replaces existing state). - return from_json(str(data)) # type: ignore[no-any-return] + Notes: + This replaces the entire state dictionary. + Updates update_time to current timestamp. + State is serialized using version-appropriate format. + """ + state_data = await self._serialize_state(state) - async def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None": - """Deserialize JSON payloads from Oracle JSON/BLOB/LOB values.""" - if data is None: - return None - return await self._deserialize_state(data) + sql = f""" + UPDATE {self._session_table} + SET state = :state, update_time = SYSTIMESTAMP + WHERE id = :id + """ - async def _serialize_event_data(self, event_data: Any) -> "str | bytes": - """Serialize event_data to the configured Oracle JSON storage format.""" - storage_type = await self._detect_json_storage_type() - if storage_type == JSONStorageType.JSON_NATIVE: - return to_json(event_data) - return to_json(event_data, as_bytes=True) + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute(sql, {"state": state_data, "id": session_id}) + await conn.commit() - async def _read_event_data(self, data: Any) -> str: - """Read event_data from database, handling LOB types. + async def delete_session(self, session_id: str) -> None: + """Delete session and all associated events (cascade). Args: - data: Data from database (may be LOB, str, or dict). + session_id: Session identifier. - Returns: - JSON string. + Notes: + Foreign key constraint ensures events are cascade-deleted. """ - if is_async_readable(data): - data = await data.read() - elif is_readable(data): - data = data.read() - - if isinstance(data, dict): - return to_json(data) - - if isinstance(data, bytes): - return data.decode("utf-8") + sql = f"DELETE FROM {self._session_table} WHERE id = :id" - return str(data) + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute(sql, {"id": session_id}) + await conn.commit() - def _get_create_sessions_table_sql_for_type(self, storage_type: JSONStorageType) -> str: - """Get Oracle CREATE TABLE SQL for sessions with specified storage type. + async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app, optionally filtered by user. Args: - storage_type: JSON storage type to use. + app_name: Application name. + user_id: User identifier. If None, lists all sessions for the app. Returns: - SQL statement to create adk_sessions table. + List of session records ordered by update_time DESC. + + Notes: + Uses composite index on (app_name, user_id) when user_id is provided. + State is deserialized using version-appropriate format. """ - if storage_type == JSONStorageType.JSON_NATIVE: - state_column = "state JSON NOT NULL" - elif storage_type == JSONStorageType.BLOB_JSON: - state_column = "state BLOB CHECK (state IS JSON) NOT NULL" + + if user_id is None: + sql = f""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {self._session_table} + WHERE app_name = :app_name + ORDER BY update_time DESC + """ + params = {"app_name": app_name} else: - state_column = "state BLOB NOT NULL" + sql = f""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {self._session_table} + WHERE app_name = :app_name AND user_id = :user_id + ORDER BY update_time DESC + """ + params = {"app_name": app_name, "user_id": user_id} - owner_id_column_sql = f", {self._owner_id_column_ddl}" if self._owner_id_column_ddl else "" - table_clauses = _oracle_table_feature_clauses( - self._config, - "session", - in_memory=self._in_memory, - hash_partition_key="id", - range_partition_key="create_time", - ) + try: + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute(sql, params) + rows = await cursor.fetchall() - return f""" - BEGIN - EXECUTE IMMEDIATE 'CREATE TABLE {self._session_table} ( - id VARCHAR2(128) PRIMARY KEY, - app_name VARCHAR2(128) NOT NULL, - user_id VARCHAR2(128) NOT NULL, - {state_column}, - create_time TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP NOT NULL, - update_time TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP NOT NULL{owner_id_column_sql} - ){table_clauses}'; - EXCEPTION - WHEN OTHERS THEN - IF SQLCODE != -955 THEN - RAISE; - END IF; - END; + results = [] + for row in rows: + state = await self._deserialize_state(row[3]) - BEGIN - EXECUTE IMMEDIATE 'CREATE INDEX idx_{self._session_table}_app_user - ON {self._session_table}(app_name, user_id)'; - EXCEPTION - WHEN OTHERS THEN - IF SQLCODE != -955 THEN - RAISE; - END IF; - END; + results.append( + SessionRecord( + id=row[0], + app_name=row[1], + user_id=row[2], + state=state, + create_time=row[4], + update_time=row[5], + ) + ) + return results + except oracledb.DatabaseError as e: + error_obj = e.args[0] if e.args else None + if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: + return [] + raise - BEGIN - EXECUTE IMMEDIATE 'CREATE INDEX idx_{self._session_table}_update_time - ON {self._session_table}(update_time DESC)'; - EXCEPTION - WHEN OTHERS THEN - IF SQLCODE != -955 THEN - RAISE; - END IF; - END; + async def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session. + + Args: + event_record: Event record with 5 keys: session_id, invocation_id, + author, timestamp, event_data. + """ + sql = f""" + INSERT INTO {self._events_table} ( + session_id, invocation_id, author, timestamp, event_data + ) VALUES ( + :session_id, :invocation_id, :author, :timestamp, :event_data + ) """ - def _get_create_events_table_sql_for_type(self, storage_type: JSONStorageType) -> str: - """Get Oracle CREATE TABLE SQL for events with specified storage type. + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute( + sql, + { + "session_id": event_record["session_id"], + "invocation_id": event_record["invocation_id"], + "author": event_record["author"], + "timestamp": event_record["timestamp"], + "event_data": await self._serialize_event_data(event_record["event_data"]), + }, + ) + await conn.commit() - The events table uses the new 5-column contract: session_id, invocation_id, - author, timestamp, and event_data. The event_data column stores the full - ADK Event as JSON (21c+) or BLOB (older versions). + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> SessionRecord: + """Atomically append an event and update the session's durable state. - Args: - storage_type: JSON storage type to use. + Both the event insert and session state update are executed within a + single transaction so they succeed or fail together. The refreshed + SessionRecord is read inside the same transaction (Oracle's RETURNING + INTO requires output bind variables which complicate async cursor + handling, so a SELECT-after-UPDATE is used instead). - Returns: - SQL statement to create adk_events table. + Args: + event_record: Event record with 5 keys: session_id, invocation_id, + author, timestamp, event_data. + session_id: Session identifier whose state should be updated. + state: Post-append durable state snapshot (``temp:`` keys already + stripped by the service layer). """ - event_data_col = _event_data_column_ddl(storage_type) - table_clauses = _oracle_table_feature_clauses( - self._config, - "events", - in_memory=self._in_memory, - hash_partition_key="session_id", - range_partition_key="timestamp", + insert_sql = f""" + INSERT INTO {self._events_table} ( + session_id, invocation_id, author, timestamp, event_data + ) VALUES ( + :session_id, :invocation_id, :author, :timestamp, :event_data ) + """ - return f""" - BEGIN - EXECUTE IMMEDIATE 'CREATE TABLE {self._events_table} ( - session_id VARCHAR2(128) NOT NULL, - invocation_id VARCHAR2(256), - author VARCHAR2(256), - timestamp TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP NOT NULL, - {event_data_col}, - CONSTRAINT fk_{self._events_table}_session FOREIGN KEY (session_id) - REFERENCES {self._session_table}(id) ON DELETE CASCADE - ){table_clauses}'; - EXCEPTION - WHEN OTHERS THEN - IF SQLCODE != -955 THEN - RAISE; - END IF; - END; + state_data = await self._serialize_state(state) + update_sql = f""" + UPDATE {self._session_table} + SET state = :state, update_time = SYSTIMESTAMP + WHERE id = :id + """ - BEGIN - EXECUTE IMMEDIATE 'CREATE INDEX idx_{self._events_table}_session - ON {self._events_table}(session_id, timestamp ASC)'; - EXCEPTION - WHEN OTHERS THEN - IF SQLCODE != -955 THEN - RAISE; - END IF; - END; + select_sql = f""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {self._session_table} + WHERE id = :id """ - def _get_drop_tables_sql(self) -> "list[str]": - """Get Oracle DROP TABLE SQL statements. + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute( + insert_sql, + { + "session_id": event_record["session_id"], + "invocation_id": event_record["invocation_id"], + "author": event_record["author"], + "timestamp": event_record["timestamp"], + "event_data": await self._serialize_event_data(event_record["event_data"]), + }, + ) + await cursor.execute(update_sql, {"state": state_data, "id": session_id}) + await cursor.execute(select_sql, {"id": session_id}) + row = await cursor.fetchone() + await conn.commit() - Returns: - List of SQL statements to drop tables and indexes. + if row is None: + msg = f"Session {session_id} not found during append_event_and_update_state." + raise ValueError(msg) - Notes: - Order matters: drop events table (child) before sessions (parent). - Oracle automatically drops indexes when dropping tables. - """ - return [ - f""" - BEGIN - EXECUTE IMMEDIATE 'DROP INDEX idx_{self._events_table}_session'; - EXCEPTION - WHEN OTHERS THEN - IF SQLCODE != -1418 THEN - RAISE; - END IF; - END; - """, - f""" - BEGIN - EXECUTE IMMEDIATE 'DROP INDEX idx_{self._session_table}_update_time'; - EXCEPTION - WHEN OTHERS THEN - IF SQLCODE != -1418 THEN - RAISE; - END IF; - END; - """, - f""" - BEGIN - EXECUTE IMMEDIATE 'DROP INDEX idx_{self._session_table}_app_user'; - EXCEPTION - WHEN OTHERS THEN - IF SQLCODE != -1418 THEN - RAISE; - END IF; - END; - """, - f""" - BEGIN - EXECUTE IMMEDIATE 'DROP TABLE {self._events_table}'; - EXCEPTION - WHEN OTHERS THEN - IF SQLCODE != -942 THEN - RAISE; - END IF; - END; - """, - f""" - BEGIN - EXECUTE IMMEDIATE 'DROP TABLE {self._session_table}'; - EXCEPTION - WHEN OTHERS THEN - IF SQLCODE != -942 THEN - RAISE; - END IF; - END; - """, - ] + session_id_val, app_name, user_id, state_data_row, create_time, update_time = row + return SessionRecord( + id=session_id_val, + app_name=app_name, + user_id=user_id, + state=await self._deserialize_state(state_data_row), + create_time=create_time, + update_time=update_time, + ) - async def create_tables(self) -> None: - """Create both sessions and events tables if they don't exist. + async def get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + """Get events for a session. - Notes: - Detects Oracle version to determine optimal JSON storage type. - Uses version-appropriate table schema. + Args: + session_id: Session identifier. + after_timestamp: Only return events after this time. + limit: Maximum number of events to return. + + Returns: + List of event records ordered by timestamp ASC. """ - storage_type = await self._detect_json_storage_type() - logger.debug("Creating ADK tables with storage type: %s", storage_type) - async with self._config.provide_session() as driver: - await driver.execute_script(self._get_create_sessions_table_sql_for_type(storage_type)) + where_clauses = ["session_id = :session_id"] + params: dict[str, Any] = {"session_id": session_id} - await driver.execute_script(self._get_create_events_table_sql_for_type(storage_type)) + if after_timestamp is not None: + where_clauses.append("timestamp > :after_timestamp") + params["after_timestamp"] = after_timestamp - async def create_session( - self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None - ) -> SessionRecord: - """Create a new session. + where_clause = " AND ".join(where_clauses) + limit_clause = "" + if limit: + limit_clause = f" FETCH FIRST {limit} ROWS ONLY" - Args: - session_id: Unique session identifier. - app_name: Application name. - user_id: User identifier. - state: Initial session state. - owner_id: Optional owner ID value for owner_id_column (if configured). + sql = f""" + SELECT session_id, invocation_id, author, timestamp, event_data + FROM {self._events_table} + WHERE {where_clause} + ORDER BY timestamp ASC{limit_clause} + """ - Returns: - Created session record. + try: + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute(sql, params) + rows = await cursor.fetchall() - Notes: - Uses SYSTIMESTAMP for create_time and update_time. - State is serialized using version-appropriate format. - owner_id is ignored if owner_id_column not configured. - """ - state_data = await self._serialize_state(state) + return [ + EventRecord( + session_id=row[0], + invocation_id=_oracle_text_value(row[1]), + author=_oracle_text_value(row[2]), + timestamp=row[3], + event_data=await self._deserialize_json_field(row[4]) or {}, + ) + for row in rows + ] + except oracledb.DatabaseError as e: + error_obj = e.args[0] if e.args else None + if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: + return [] + raise - if self._owner_id_column_name: - sql = f""" - INSERT INTO {self._session_table} (id, app_name, user_id, state, create_time, update_time, {self._owner_id_column_name}) - VALUES (:id, :app_name, :user_id, :state, SYSTIMESTAMP, SYSTIMESTAMP, :owner_id) - """ - params = { - "id": session_id, - "app_name": app_name, - "user_id": user_id, - "state": state_data, - "owner_id": owner_id, - } - else: - sql = f""" - INSERT INTO {self._session_table} (id, app_name, user_id, state, create_time, update_time) - VALUES (:id, :app_name, :user_id, :state, SYSTIMESTAMP, SYSTIMESTAMP) - """ - params = {"id": session_id, "app_name": app_name, "user_id": user_id, "state": state_data} + async def delete_expired_events(self, before: "datetime") -> int: + sql = f"DELETE FROM {self._events_table} WHERE timestamp < :before" - async with self._config.provide_connection() as conn: - cursor = conn.cursor() - await cursor.execute(sql, params) - await conn.commit() + try: + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute(sql, {"before": before}) + await conn.commit() + return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + except oracledb.DatabaseError as e: + error_obj = e.args[0] if e.args else None + if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: + return 0 + raise - return await self.get_session(session_id) # type: ignore[return-value] + async def delete_idle_sessions(self, updated_before: "datetime") -> int: + sql = f"DELETE FROM {self._session_table} WHERE update_time < :updated_before" - async def get_session( - self, session_id: str, *, renew_for: "int | timedelta | None" = None - ) -> "SessionRecord | None": - """Get session by ID. + try: + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute(sql, {"updated_before": updated_before}) + await conn.commit() + return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + except oracledb.DatabaseError as e: + error_obj = e.args[0] if e.args else None + if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: + return 0 + raise - Args: - session_id: Session identifier. - renew_for: If positive, touch update_time while reading. + async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Return app-scoped state for an application.""" + sql = f"SELECT state FROM {self._app_state_table} WHERE app_name = :app_name" - Returns: - Session record or None if not found. + try: + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute(sql, {"app_name": app_name}) + row = await cursor.fetchone() + return await self._deserialize_state(row[0]) if row is not None else None + except oracledb.DatabaseError as e: + error_obj = e.args[0] if e.args else None + if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: + return None + raise - Notes: - Oracle returns datetime objects for TIMESTAMP columns. - State is deserialized using version-appropriate format. + async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Return user-scoped state for an application user.""" + sql = f""" + SELECT state + FROM {self._user_state_table} + WHERE app_name = :app_name AND user_id = :user_id """ try: async with self._config.provide_connection() as conn: cursor = conn.cursor() - if renew_for is not None and self._calculate_expires_at(renew_for) is not None: - await cursor.execute( - f"UPDATE {self._session_table} SET update_time = SYSTIMESTAMP WHERE id = :id", - {"id": session_id}, - ) - await conn.commit() - - await cursor.execute( - f""" - SELECT id, app_name, user_id, state, create_time, update_time - FROM {self._session_table} - WHERE id = :id - """, - {"id": session_id}, - ) + await cursor.execute(sql, {"app_name": app_name, "user_id": user_id}) row = await cursor.fetchone() + return await self._deserialize_state(row[0]) if row is not None else None + except oracledb.DatabaseError as e: + error_obj = e.args[0] if e.args else None + if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: + return None + raise - if row is None: - return None + async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Insert or replace app-scoped state for an application.""" + sql = f""" + MERGE INTO {self._app_state_table} target + USING (SELECT :app_name AS app_name, :state AS state FROM DUAL) source + ON (target.app_name = source.app_name) + WHEN MATCHED THEN + UPDATE SET target.state = source.state, target.update_time = SYSTIMESTAMP + WHEN NOT MATCHED THEN + INSERT (app_name, state, update_time) + VALUES (source.app_name, source.state, SYSTIMESTAMP) + """ - session_id_val, app_name, user_id, state_data, create_time, update_time = row + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute(sql, {"app_name": app_name, "state": await self._serialize_state(state)}) + await conn.commit() - state = await self._deserialize_state(state_data) + async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Insert or replace user-scoped state for an application user.""" + sql = f""" + MERGE INTO {self._user_state_table} target + USING (SELECT :app_name AS app_name, :user_id AS user_id, :state AS state FROM DUAL) source + ON (target.app_name = source.app_name AND target.user_id = source.user_id) + WHEN MATCHED THEN + UPDATE SET target.state = source.state, target.update_time = SYSTIMESTAMP + WHEN NOT MATCHED THEN + INSERT (app_name, user_id, state, update_time) + VALUES (source.app_name, source.user_id, source.state, SYSTIMESTAMP) + """ - return SessionRecord( - id=session_id_val, - app_name=app_name, - user_id=user_id, - state=state, - create_time=create_time, - update_time=update_time, - ) + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute( + sql, {"app_name": app_name, "user_id": user_id, "state": await self._serialize_state(state)} + ) + await conn.commit() + + async def get_metadata(self, key: str) -> "str | None": + """Return a value from the ADK internal metadata table.""" + sql = f"SELECT value FROM {self._metadata_table} WHERE key = :key" + + try: + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute(sql, {"key": key}) + row = await cursor.fetchone() + return str(row[0]) if row is not None else None except oracledb.DatabaseError as e: error_obj = e.args[0] if e.args else None if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: return None raise - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state. + async def set_metadata(self, key: str, value: str) -> None: + """Set a value in the ADK internal metadata table.""" + sql = f""" + MERGE INTO {self._metadata_table} target + USING (SELECT :key AS key, :value AS value FROM DUAL) source + ON (target.key = source.key) + WHEN MATCHED THEN + UPDATE SET target.value = source.value + WHEN NOT MATCHED THEN + INSERT (key, value) + VALUES (source.key, source.value) + """ + + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute(sql, {"key": key, "value": value}) + await conn.commit() + + async def _get_create_sessions_table_sql(self) -> str: + """Get Oracle CREATE TABLE SQL for sessions table. + + Auto-detects optimal JSON storage type based on Oracle version. + Result is cached to minimize database queries. + """ + storage_type = await self._detect_json_storage_type() + return self._get_create_sessions_table_sql_for_type(storage_type) + + async def _get_create_events_table_sql(self) -> str: + """Get Oracle CREATE TABLE SQL for events table. + + Auto-detects optimal JSON storage type based on Oracle version. + Result is cached to minimize database queries. + """ + storage_type = await self._detect_json_storage_type() + return self._get_create_events_table_sql_for_type(storage_type) + + async def _get_create_app_states_table_sql(self) -> str: + """Get Oracle CREATE TABLE SQL for app-scoped state.""" + storage_type = await self._detect_json_storage_type() + return self._get_create_app_states_table_sql_for_type(storage_type) + + async def _get_create_user_states_table_sql(self) -> str: + """Get Oracle CREATE TABLE SQL for user-scoped state.""" + storage_type = await self._detect_json_storage_type() + return self._get_create_user_states_table_sql_for_type(storage_type) + + async def _get_create_metadata_table_sql(self) -> str: + """Get Oracle CREATE TABLE SQL for ADK internal metadata.""" + return f""" + BEGIN + EXECUTE IMMEDIATE 'CREATE TABLE {self._metadata_table} ( + key VARCHAR2(128) PRIMARY KEY, + value VARCHAR2(512) NOT NULL + )'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -955 THEN + RAISE; + END IF; + END; + """ + + async def _get_seed_metadata_sql(self) -> str: + """Get Oracle SQL to seed the ADK schema-version metadata row.""" + return f""" + BEGIN + INSERT INTO {self._metadata_table} (key, value) + SELECT 'schema_version', '1' + FROM DUAL + WHERE NOT EXISTS ( + SELECT 1 FROM {self._metadata_table} WHERE key = 'schema_version' + ); + END; + """ - Args: - session_id: Session identifier. - state: New state dictionary (replaces existing state). + async def _detect_json_storage_type(self) -> JSONStorageType: + """Detect the appropriate JSON storage type based on Oracle version. + + Returns: + Appropriate JSONStorageType for this Oracle version. Notes: - This replaces the entire state dictionary. - Updates update_time to current timestamp. - State is serialized using version-appropriate format. - """ - state_data = await self._serialize_state(state) + Queries product_component_version to determine Oracle version. + - Oracle 21c+ with compatible >= 20: Native JSON type + - Oracle 12c+: BLOB with IS JSON constraint + - Oracle 11g and earlier: plain BLOB - sql = f""" - UPDATE {self._session_table} - SET state = :state, update_time = SYSTIMESTAMP - WHERE id = :id + Result is cached in self._json_storage_type. """ + if self._json_storage_type is not None: + return self._json_storage_type - async with self._config.provide_connection() as conn: - cursor = conn.cursor() - await cursor.execute(sql, {"state": state_data, "id": session_id}) - await conn.commit() + version_info = await self._get_version_info() + self._json_storage_type = _storage_type_from_version(version_info) + return self._json_storage_type - async def delete_session(self, session_id: str) -> None: - """Delete session and all associated events (cascade). + async def _get_version_info(self) -> "OracleVersionInfo | None": + """Return cached Oracle version info using Oracle data dictionary.""" + + if self._oracle_version_info is not None: + return self._oracle_version_info + + async with self._config.provide_session() as driver: + dictionary = OracledbAsyncDataDictionary() + self._oracle_version_info = await dictionary.get_version(driver) + + if self._oracle_version_info is None: + logger.warning("Could not detect Oracle version, defaulting to BLOB_JSON storage") + + return self._oracle_version_info + + async def _serialize_state(self, state: "dict[str, Any]") -> "str | bytes": + """Serialize state dictionary to appropriate format based on storage type. Args: - session_id: Session identifier. + state: State dictionary to serialize. - Notes: - Foreign key constraint ensures events are cascade-deleted. + Returns: + JSON string for JSON_NATIVE, bytes for BLOB types. """ - sql = f"DELETE FROM {self._session_table} WHERE id = :id" + storage_type = await self._detect_json_storage_type() - async with self._config.provide_connection() as conn: - cursor = conn.cursor() - await cursor.execute(sql, {"id": session_id}) - await conn.commit() + if storage_type == JSONStorageType.JSON_NATIVE: + return to_json(state) - async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app, optionally filtered by user. + return to_json(state, as_bytes=True) + + async def _deserialize_state(self, data: Any) -> "dict[str, Any]": + """Deserialize state data from database format. Args: - app_name: Application name. - user_id: User identifier. If None, lists all sessions for the app. + data: Data from database (may be LOB, str, bytes, or dict). Returns: - List of session records ordered by update_time DESC. + Deserialized state dictionary. Notes: - Uses composite index on (app_name, user_id) when user_id is provided. - State is deserialized using version-appropriate format. + Handles LOB reading if data has read() method. + Oracle JSON type may return dict directly. """ + if is_async_readable(data): + data = await data.read() + elif is_readable(data): + data = data.read() - if user_id is None: - sql = f""" - SELECT id, app_name, user_id, state, create_time, update_time - FROM {self._session_table} - WHERE app_name = :app_name - ORDER BY update_time DESC - """ - params = {"app_name": app_name} - else: - sql = f""" - SELECT id, app_name, user_id, state, create_time, update_time - FROM {self._session_table} - WHERE app_name = :app_name AND user_id = :user_id - ORDER BY update_time DESC - """ - params = {"app_name": app_name, "user_id": user_id} + if isinstance(data, dict): + return cast("dict[str, Any]", _coerce_decimal_values(data)) - try: - async with self._config.provide_connection() as conn: - cursor = conn.cursor() - await cursor.execute(sql, params) - rows = await cursor.fetchall() + if isinstance(data, bytes): + return from_json(data) # type: ignore[no-any-return] - results = [] - for row in rows: - state = await self._deserialize_state(row[3]) + if isinstance(data, str): + return from_json(data) # type: ignore[no-any-return] - results.append( - SessionRecord( - id=row[0], - app_name=row[1], - user_id=row[2], - state=state, - create_time=row[4], - update_time=row[5], - ) - ) - return results - except oracledb.DatabaseError as e: - error_obj = e.args[0] if e.args else None - if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: - return [] - raise + return from_json(str(data)) # type: ignore[no-any-return] - async def append_event(self, event_record: EventRecord) -> None: - """Append an event to a session. + async def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None": + """Deserialize JSON payloads from Oracle JSON/BLOB/LOB values.""" + if data is None: + return None + return await self._deserialize_state(data) + + async def _serialize_event_data(self, event_data: Any) -> "str | bytes": + """Serialize event_data to the configured Oracle JSON storage format.""" + storage_type = await self._detect_json_storage_type() + if storage_type == JSONStorageType.JSON_NATIVE: + return to_json(event_data) + return to_json(event_data, as_bytes=True) + + async def _read_event_data(self, data: Any) -> str: + """Read event_data from database, handling LOB types. Args: - event_record: Event record with 5 keys: session_id, invocation_id, - author, timestamp, event_data. - """ - sql = f""" - INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_data - ) VALUES ( - :session_id, :invocation_id, :author, :timestamp, :event_data - ) + data: Data from database (may be LOB, str, or dict). + + Returns: + JSON string. """ + if is_async_readable(data): + data = await data.read() + elif is_readable(data): + data = data.read() - async with self._config.provide_connection() as conn: - cursor = conn.cursor() - await cursor.execute( - sql, - { - "session_id": event_record["session_id"], - "invocation_id": event_record["invocation_id"], - "author": event_record["author"], - "timestamp": event_record["timestamp"], - "event_data": await self._serialize_event_data(event_record["event_data"]), - }, - ) - await conn.commit() + if isinstance(data, dict): + return to_json(data) - async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" - ) -> SessionRecord: - """Atomically append an event and update the session's durable state. + if isinstance(data, bytes): + return data.decode("utf-8") - Both the event insert and session state update are executed within a - single transaction so they succeed or fail together. The refreshed - SessionRecord is read inside the same transaction (Oracle's RETURNING - INTO requires output bind variables which complicate async cursor - handling, so a SELECT-after-UPDATE is used instead). + return str(data) + + def _get_create_sessions_table_sql_for_type(self, storage_type: JSONStorageType) -> str: + """Get Oracle CREATE TABLE SQL for sessions with specified storage type. Args: - event_record: Event record with 5 keys: session_id, invocation_id, - author, timestamp, event_data. - session_id: Session identifier whose state should be updated. - state: Post-append durable state snapshot (``temp:`` keys already - stripped by the service layer). + storage_type: JSON storage type to use. + + Returns: + SQL statement to create adk_session table. """ - insert_sql = f""" - INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_data - ) VALUES ( - :session_id, :invocation_id, :author, :timestamp, :event_data + if storage_type == JSONStorageType.JSON_NATIVE: + state_column = "state JSON NOT NULL" + elif storage_type == JSONStorageType.BLOB_JSON: + state_column = "state BLOB CHECK (state IS JSON) NOT NULL" + else: + state_column = "state BLOB NOT NULL" + + owner_id_column_sql = f", {self._owner_id_column_ddl}" if self._owner_id_column_ddl else "" + table_clauses = _oracle_table_feature_clauses( + self._config, + "session", + in_memory=self._in_memory, + hash_partition_key="id", + range_partition_key="create_time", ) - """ - state_data = await self._serialize_state(state) - update_sql = f""" - UPDATE {self._session_table} - SET state = :state, update_time = SYSTIMESTAMP - WHERE id = :id + return f""" + BEGIN + EXECUTE IMMEDIATE 'CREATE TABLE {self._session_table} ( + id VARCHAR2(128) PRIMARY KEY, + app_name VARCHAR2(128) NOT NULL, + user_id VARCHAR2(128) NOT NULL, + {state_column}, + create_time TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP NOT NULL, + update_time TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP NOT NULL{owner_id_column_sql} + ){table_clauses}'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -955 THEN + RAISE; + END IF; + END; + + BEGIN + EXECUTE IMMEDIATE 'CREATE INDEX idx_{self._session_table}_app_user + ON {self._session_table}(app_name, user_id)'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -955 THEN + RAISE; + END IF; + END; + + BEGIN + EXECUTE IMMEDIATE 'CREATE INDEX idx_{self._session_table}_update_time + ON {self._session_table}(update_time DESC)'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -955 THEN + RAISE; + END IF; + END; """ - select_sql = f""" - SELECT id, app_name, user_id, state, create_time, update_time - FROM {self._session_table} - WHERE id = :id - """ + def _get_create_events_table_sql_for_type(self, storage_type: JSONStorageType) -> str: + """Get Oracle CREATE TABLE SQL for events with specified storage type. - async with self._config.provide_connection() as conn: - cursor = conn.cursor() - await cursor.execute( - insert_sql, - { - "session_id": event_record["session_id"], - "invocation_id": event_record["invocation_id"], - "author": event_record["author"], - "timestamp": event_record["timestamp"], - "event_data": await self._serialize_event_data(event_record["event_data"]), - }, - ) - await cursor.execute(update_sql, {"state": state_data, "id": session_id}) - await cursor.execute(select_sql, {"id": session_id}) - row = await cursor.fetchone() - await conn.commit() + The events table uses the new 5-column contract: session_id, invocation_id, + author, timestamp, and event_data. The event_data column stores the full + ADK Event as JSON (21c+) or BLOB (older versions). - if row is None: - msg = f"Session {session_id} not found during append_event_and_update_state." - raise ValueError(msg) + Args: + storage_type: JSON storage type to use. - session_id_val, app_name, user_id, state_data_row, create_time, update_time = row - return SessionRecord( - id=session_id_val, - app_name=app_name, - user_id=user_id, - state=await self._deserialize_state(state_data_row), - create_time=create_time, - update_time=update_time, + Returns: + SQL statement to create adk_event table. + """ + event_data_col = _event_data_column_ddl(storage_type) + table_clauses = _oracle_table_feature_clauses( + self._config, + "events", + in_memory=self._in_memory, + hash_partition_key="session_id", + range_partition_key="timestamp", ) - async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None - ) -> "list[EventRecord]": - """Get events for a session. - - Args: - session_id: Session identifier. - after_timestamp: Only return events after this time. - limit: Maximum number of events to return. + return f""" + BEGIN + EXECUTE IMMEDIATE 'CREATE TABLE {self._events_table} ( + session_id VARCHAR2(128) NOT NULL, + invocation_id VARCHAR2(256), + author VARCHAR2(256), + timestamp TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP NOT NULL, + {event_data_col}, + CONSTRAINT fk_{self._events_table}_session FOREIGN KEY (session_id) + REFERENCES {self._session_table}(id) ON DELETE CASCADE + ){table_clauses}'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -955 THEN + RAISE; + END IF; + END; - Returns: - List of event records ordered by timestamp ASC. + BEGIN + EXECUTE IMMEDIATE 'CREATE INDEX idx_{self._events_table}_session + ON {self._events_table}(session_id, timestamp ASC)'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -955 THEN + RAISE; + END IF; + END; """ - where_clauses = ["session_id = :session_id"] - params: dict[str, Any] = {"session_id": session_id} + def _get_create_app_states_table_sql_for_type(self, storage_type: JSONStorageType) -> str: + """Get Oracle CREATE TABLE SQL for app-scoped state with specified storage type.""" + state_column = _json_column_ddl("state", storage_type) - if after_timestamp is not None: - where_clauses.append("timestamp > :after_timestamp") - params["after_timestamp"] = after_timestamp + return f""" + BEGIN + EXECUTE IMMEDIATE 'CREATE TABLE {self._app_state_table} ( + app_name VARCHAR2(128) PRIMARY KEY, + {state_column}, + update_time TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP NOT NULL + )'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -955 THEN + RAISE; + END IF; + END; + """ - where_clause = " AND ".join(where_clauses) - limit_clause = "" - if limit: - limit_clause = f" FETCH FIRST {limit} ROWS ONLY" + def _get_create_user_states_table_sql_for_type(self, storage_type: JSONStorageType) -> str: + """Get Oracle CREATE TABLE SQL for user-scoped state with specified storage type.""" + state_column = _json_column_ddl("state", storage_type) - sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_data - FROM {self._events_table} - WHERE {where_clause} - ORDER BY timestamp ASC{limit_clause} + return f""" + BEGIN + EXECUTE IMMEDIATE 'CREATE TABLE {self._user_state_table} ( + app_name VARCHAR2(128) NOT NULL, + user_id VARCHAR2(128) NOT NULL, + {state_column}, + update_time TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP NOT NULL, + PRIMARY KEY (app_name, user_id) + )'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -955 THEN + RAISE; + END IF; + END; """ - try: - async with self._config.provide_connection() as conn: - cursor = conn.cursor() - await cursor.execute(sql, params) - rows = await cursor.fetchall() + def _get_drop_app_states_table_sql(self) -> str: + return f""" + BEGIN + EXECUTE IMMEDIATE 'DROP TABLE {self._app_state_table}'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -942 THEN + RAISE; + END IF; + END; + """ - return [ - EventRecord( - session_id=row[0], - invocation_id=_oracle_text_value(row[1]), - author=_oracle_text_value(row[2]), - timestamp=row[3], - event_data=await self._deserialize_json_field(row[4]) or {}, - ) - for row in rows - ] - except oracledb.DatabaseError as e: - error_obj = e.args[0] if e.args else None - if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: - return [] - raise + def _get_drop_user_states_table_sql(self) -> str: + return f""" + BEGIN + EXECUTE IMMEDIATE 'DROP TABLE {self._user_state_table}'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -942 THEN + RAISE; + END IF; + END; + """ - async def delete_expired_events(self, before: "datetime") -> int: - sql = f"DELETE FROM {self._events_table} WHERE timestamp < :before" + def _get_drop_metadata_table_sql(self) -> str: + return f""" + BEGIN + EXECUTE IMMEDIATE 'DROP TABLE {self._metadata_table}'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -942 THEN + RAISE; + END IF; + END; + """ - try: - async with self._config.provide_connection() as conn: - cursor = conn.cursor() - await cursor.execute(sql, {"before": before}) - await conn.commit() - return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 - except oracledb.DatabaseError as e: - error_obj = e.args[0] if e.args else None - if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: - return 0 - raise + def _get_drop_tables_sql(self) -> "list[str]": + """Get Oracle DROP TABLE SQL statements. - async def delete_idle_sessions(self, updated_before: "datetime") -> int: - sql = f"DELETE FROM {self._session_table} WHERE update_time < :updated_before" + Returns: + List of SQL statements to drop tables and indexes. - try: - async with self._config.provide_connection() as conn: - cursor = conn.cursor() - await cursor.execute(sql, {"updated_before": updated_before}) - await conn.commit() - return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 - except oracledb.DatabaseError as e: - error_obj = e.args[0] if e.args else None - if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: - return 0 - raise + Notes: + Order matters: drop events table (child) before sessions (parent). + Oracle automatically drops indexes when dropping tables. + """ + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f""" + BEGIN + EXECUTE IMMEDIATE 'DROP INDEX idx_{self._events_table}_session'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -1418 THEN + RAISE; + END IF; + END; + """, + f""" + BEGIN + EXECUTE IMMEDIATE 'DROP INDEX idx_{self._session_table}_update_time'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -1418 THEN + RAISE; + END IF; + END; + """, + f""" + BEGIN + EXECUTE IMMEDIATE 'DROP INDEX idx_{self._session_table}_app_user'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -1418 THEN + RAISE; + END IF; + END; + """, + f""" + BEGIN + EXECUTE IMMEDIATE 'DROP TABLE {self._events_table}'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -942 THEN + RAISE; + END IF; + END; + """, + f""" + BEGIN + EXECUTE IMMEDIATE 'DROP TABLE {self._session_table}'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -942 THEN + RAISE; + END IF; + END; + """, + ] class OracleSyncADKStore(BaseAsyncADKStore["OracleSyncConfig"]): @@ -876,8 +1106,8 @@ def __init__(self, config: "OracleSyncConfig") -> None: Notes: Configuration is read from config.extension_config["adk"]: - - session_table: Sessions table name (default: "adk_sessions") - - events_table: Events table name (default: "adk_events") + - session_table: Sessions table name (default: "adk_session") + - events_table: Events table name (default: "adk_event") - owner_id_column: Optional owner FK column DDL (default: None) - in_memory: Enable INMEMORY PRIORITY HIGH clause (default: False) """ @@ -888,6 +1118,82 @@ def __init__(self, config: "OracleSyncConfig") -> None: adk_config = config.extension_config.get("adk", {}) self._in_memory: bool = bool(adk_config.get("in_memory", False)) + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + + async def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new session.""" + return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) + + async def get_session( + self, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + """Get session by ID.""" + return await async_(self._get_session)(session_id, renew_for) + + async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + """Update session state.""" + await async_(self._update_session_state)(session_id, state) + + async def delete_session(self, session_id: str) -> None: + """Delete session and associated events.""" + await async_(self._delete_session)(session_id) + + async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app.""" + return await async_(self._list_sessions)(app_name, user_id) + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> SessionRecord: + """Atomically append an event and update the session's durable state.""" + return await async_(self._append_event_and_update_state)(event_record, session_id, state) + + async def get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + """Get events for a session.""" + return await async_(self._get_events)(session_id, after_timestamp, limit) + + async def delete_expired_events(self, before: "datetime") -> int: + """Delete events older than the given timestamp.""" + return await async_(self._delete_expired_events)(before) + + async def delete_idle_sessions(self, updated_before: "datetime") -> int: + """Delete sessions whose update_time predates the given threshold.""" + return await async_(self._delete_idle_sessions)(updated_before) + + async def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session.""" + await async_(self._append_event)(event_record) + + async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Return app-scoped state for an application.""" + return await async_(self._get_app_state)(app_name) + + async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Return user-scoped state for an application user.""" + return await async_(self._get_user_state)(app_name, user_id) + + async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Insert or replace app-scoped state for an application.""" + await async_(self._upsert_app_state)(app_name, state) + + async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Insert or replace user-scoped state for an application user.""" + await async_(self._upsert_user_state)(app_name, user_id, state) + + async def get_metadata(self, key: str) -> "str | None": + """Return a value from the ADK internal metadata table.""" + return await async_(self._get_metadata)(key) + + async def set_metadata(self, key: str, value: str) -> None: + """Set a value in the ADK internal metadata table.""" + await async_(self._set_metadata)(key, value) + async def _get_create_sessions_table_sql(self) -> str: """Get Oracle CREATE TABLE SQL for sessions table. @@ -906,6 +1212,45 @@ async def _get_create_events_table_sql(self) -> str: storage_type = self._detect_json_storage_type() return self._get_create_events_table_sql_for_type(storage_type) + async def _get_create_app_states_table_sql(self) -> str: + """Get Oracle CREATE TABLE SQL for app-scoped state.""" + storage_type = self._detect_json_storage_type() + return self._get_create_app_states_table_sql_for_type(storage_type) + + async def _get_create_user_states_table_sql(self) -> str: + """Get Oracle CREATE TABLE SQL for user-scoped state.""" + storage_type = self._detect_json_storage_type() + return self._get_create_user_states_table_sql_for_type(storage_type) + + async def _get_create_metadata_table_sql(self) -> str: + """Get Oracle CREATE TABLE SQL for ADK internal metadata.""" + return f""" + BEGIN + EXECUTE IMMEDIATE 'CREATE TABLE {self._metadata_table} ( + key VARCHAR2(128) PRIMARY KEY, + value VARCHAR2(512) NOT NULL + )'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -955 THEN + RAISE; + END IF; + END; + """ + + async def _get_seed_metadata_sql(self) -> str: + """Get Oracle SQL to seed the ADK schema-version metadata row.""" + return f""" + BEGIN + INSERT INTO {self._metadata_table} (key, value) + SELECT 'schema_version', '1' + FROM DUAL + WHERE NOT EXISTS ( + SELECT 1 FROM {self._metadata_table} WHERE key = 'schema_version' + ); + END; + """ + def _detect_json_storage_type(self) -> JSONStorageType: """Detect the appropriate JSON storage type based on Oracle version. @@ -1025,7 +1370,7 @@ def _get_create_sessions_table_sql_for_type(self, storage_type: JSONStorageType) storage_type: JSON storage type to use. Returns: - SQL statement to create adk_sessions table. + SQL statement to create adk_session table. """ if storage_type == JSONStorageType.JSON_NATIVE: state_column = "state JSON NOT NULL" @@ -1092,7 +1437,7 @@ def _get_create_events_table_sql_for_type(self, storage_type: JSONStorageType) - storage_type: JSON storage type to use. Returns: - SQL statement to create adk_events table. + SQL statement to create adk_event table. """ event_data_col = _event_data_column_ddl(storage_type) table_clauses = _oracle_table_feature_clauses( @@ -1132,6 +1477,82 @@ def _get_create_events_table_sql_for_type(self, storage_type: JSONStorageType) - END; """ + def _get_create_app_states_table_sql_for_type(self, storage_type: JSONStorageType) -> str: + """Get Oracle CREATE TABLE SQL for app-scoped state with specified storage type.""" + state_column = _json_column_ddl("state", storage_type) + + return f""" + BEGIN + EXECUTE IMMEDIATE 'CREATE TABLE {self._app_state_table} ( + app_name VARCHAR2(128) PRIMARY KEY, + {state_column}, + update_time TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP NOT NULL + )'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -955 THEN + RAISE; + END IF; + END; + """ + + def _get_create_user_states_table_sql_for_type(self, storage_type: JSONStorageType) -> str: + """Get Oracle CREATE TABLE SQL for user-scoped state with specified storage type.""" + state_column = _json_column_ddl("state", storage_type) + + return f""" + BEGIN + EXECUTE IMMEDIATE 'CREATE TABLE {self._user_state_table} ( + app_name VARCHAR2(128) NOT NULL, + user_id VARCHAR2(128) NOT NULL, + {state_column}, + update_time TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP NOT NULL, + PRIMARY KEY (app_name, user_id) + )'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -955 THEN + RAISE; + END IF; + END; + """ + + def _get_drop_app_states_table_sql(self) -> str: + return f""" + BEGIN + EXECUTE IMMEDIATE 'DROP TABLE {self._app_state_table}'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -942 THEN + RAISE; + END IF; + END; + """ + + def _get_drop_user_states_table_sql(self) -> str: + return f""" + BEGIN + EXECUTE IMMEDIATE 'DROP TABLE {self._user_state_table}'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -942 THEN + RAISE; + END IF; + END; + """ + + def _get_drop_metadata_table_sql(self) -> str: + return f""" + BEGIN + EXECUTE IMMEDIATE 'DROP TABLE {self._metadata_table}'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -942 THEN + RAISE; + END IF; + END; + """ + def _get_drop_tables_sql(self) -> "list[str]": """Get Oracle DROP TABLE SQL statements. @@ -1143,6 +1564,9 @@ def _get_drop_tables_sql(self) -> "list[str]": Oracle automatically drops indexes when dropping tables. """ return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), f""" BEGIN EXECUTE IMMEDIATE 'DROP INDEX idx_{self._events_table}_session'; @@ -1211,10 +1635,10 @@ def _create_tables(self) -> None: events_sql = SQL(self._get_create_events_table_sql_for_type(storage_type)) driver.execute_script(events_sql) - - async def create_tables(self) -> None: - """Create tables if they don't exist.""" - await async_(self._create_tables)() + driver.execute_script(SQL(self._get_create_app_states_table_sql_for_type(storage_type))) + driver.execute_script(SQL(self._get_create_user_states_table_sql_for_type(storage_type))) + driver.execute_script(SQL(run_(self._get_create_metadata_table_sql)())) + driver.execute_script(SQL(run_(self._get_seed_metadata_sql)())) def _create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None @@ -1268,12 +1692,6 @@ def _create_session( raise RuntimeError(msg) return result - async def create_session( - self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None - ) -> SessionRecord: - """Create a new session.""" - return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) - def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = None) -> "SessionRecord | None": """Get session by ID. @@ -1329,12 +1747,6 @@ def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = No return None raise - async def get_session( - self, session_id: str, *, renew_for: "int | timedelta | None" = None - ) -> "SessionRecord | None": - """Get session by ID.""" - return await async_(self._get_session)(session_id, renew_for) - def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: """Update session state. @@ -1360,10 +1772,6 @@ def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> Non cursor.execute(sql, {"state": state_data, "id": session_id}) conn.commit() - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state.""" - await async_(self._update_session_state)(session_id, state) - def _delete_session(self, session_id: str) -> None: """Delete session and all associated events (cascade). @@ -1380,10 +1788,6 @@ def _delete_session(self, session_id: str) -> None: cursor.execute(sql, {"id": session_id}) conn.commit() - async def delete_session(self, session_id: str) -> None: - """Delete session and associated events.""" - await async_(self._delete_session)(session_id) - def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": """List sessions for an app, optionally filtered by user. @@ -1443,10 +1847,6 @@ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Ses return [] raise - async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app.""" - return await async_(self._list_sessions)(app_name, user_id) - def _append_event_and_update_state( self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" ) -> SessionRecord: @@ -1515,12 +1915,6 @@ def _append_event_and_update_state( update_time=update_time, ) - async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" - ) -> SessionRecord: - """Atomically append an event and update the session's durable state.""" - return await async_(self._append_event_and_update_state)(event_record, session_id, state) - def _get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None ) -> "list[EventRecord]": @@ -1573,12 +1967,6 @@ def _get_events( return [] raise - async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None - ) -> "list[EventRecord]": - """Get events for a session.""" - return await async_(self._get_events)(session_id, after_timestamp, limit) - def _delete_expired_events(self, before: "datetime") -> int: sql = f"DELETE FROM {self._events_table} WHERE timestamp < :before" @@ -1594,10 +1982,6 @@ def _delete_expired_events(self, before: "datetime") -> int: return 0 raise - async def delete_expired_events(self, before: "datetime") -> int: - """Delete events older than the given timestamp.""" - return await async_(self._delete_expired_events)(before) - def _delete_idle_sessions(self, updated_before: "datetime") -> int: sql = f"DELETE FROM {self._session_table} WHERE update_time < :updated_before" @@ -1613,38 +1997,136 @@ def _delete_idle_sessions(self, updated_before: "datetime") -> int: return 0 raise - async def delete_idle_sessions(self, updated_before: "datetime") -> int: - """Delete sessions whose update_time predates the given threshold.""" - return await async_(self._delete_idle_sessions)(updated_before) + def _append_event(self, event_record: EventRecord) -> None: + """Synchronous implementation of append_event.""" + sql = f""" + INSERT INTO {self._events_table} ( + session_id, invocation_id, author, timestamp, event_data + ) VALUES ( + :session_id, :invocation_id, :author, :timestamp, :event_data + ) + """ + + with self._config.provide_connection() as conn: + cursor = conn.cursor() + cursor.execute( + sql, + { + "session_id": event_record["session_id"], + "invocation_id": event_record["invocation_id"], + "author": event_record["author"], + "timestamp": event_record["timestamp"], + "event_data": self._serialize_event_data(event_record["event_data"]), + }, + ) + conn.commit() + + def _get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Synchronous implementation of get_app_state.""" + sql = f"SELECT state FROM {self._app_state_table} WHERE app_name = :app_name" + + try: + with self._config.provide_connection() as conn: + cursor = conn.cursor() + cursor.execute(sql, {"app_name": app_name}) + row = cursor.fetchone() + return self._deserialize_state(row[0]) if row is not None else None + except oracledb.DatabaseError as e: + error_obj = e.args[0] if e.args else None + if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: + return None + raise + + def _get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Synchronous implementation of get_user_state.""" + sql = f""" + SELECT state + FROM {self._user_state_table} + WHERE app_name = :app_name AND user_id = :user_id + """ + + try: + with self._config.provide_connection() as conn: + cursor = conn.cursor() + cursor.execute(sql, {"app_name": app_name, "user_id": user_id}) + row = cursor.fetchone() + return self._deserialize_state(row[0]) if row is not None else None + except oracledb.DatabaseError as e: + error_obj = e.args[0] if e.args else None + if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: + return None + raise + + def _upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Synchronous implementation of upsert_app_state.""" + sql = f""" + MERGE INTO {self._app_state_table} target + USING (SELECT :app_name AS app_name, :state AS state FROM DUAL) source + ON (target.app_name = source.app_name) + WHEN MATCHED THEN + UPDATE SET target.state = source.state, target.update_time = SYSTIMESTAMP + WHEN NOT MATCHED THEN + INSERT (app_name, state, update_time) + VALUES (source.app_name, source.state, SYSTIMESTAMP) + """ + + with self._config.provide_connection() as conn: + cursor = conn.cursor() + cursor.execute(sql, {"app_name": app_name, "state": self._serialize_state(state)}) + conn.commit() + + def _upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Synchronous implementation of upsert_user_state.""" + sql = f""" + MERGE INTO {self._user_state_table} target + USING (SELECT :app_name AS app_name, :user_id AS user_id, :state AS state FROM DUAL) source + ON (target.app_name = source.app_name AND target.user_id = source.user_id) + WHEN MATCHED THEN + UPDATE SET target.state = source.state, target.update_time = SYSTIMESTAMP + WHEN NOT MATCHED THEN + INSERT (app_name, user_id, state, update_time) + VALUES (source.app_name, source.user_id, source.state, SYSTIMESTAMP) + """ + + with self._config.provide_connection() as conn: + cursor = conn.cursor() + cursor.execute(sql, {"app_name": app_name, "user_id": user_id, "state": self._serialize_state(state)}) + conn.commit() + + def _get_metadata(self, key: str) -> "str | None": + """Synchronous implementation of get_metadata.""" + sql = f"SELECT value FROM {self._metadata_table} WHERE key = :key" + + try: + with self._config.provide_connection() as conn: + cursor = conn.cursor() + cursor.execute(sql, {"key": key}) + row = cursor.fetchone() + return str(row[0]) if row is not None else None + except oracledb.DatabaseError as e: + error_obj = e.args[0] if e.args else None + if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: + return None + raise - def _append_event(self, event_record: EventRecord) -> None: - """Synchronous implementation of append_event.""" + def _set_metadata(self, key: str, value: str) -> None: + """Synchronous implementation of set_metadata.""" sql = f""" - INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_data - ) VALUES ( - :session_id, :invocation_id, :author, :timestamp, :event_data - ) + MERGE INTO {self._metadata_table} target + USING (SELECT :key AS key, :value AS value FROM DUAL) source + ON (target.key = source.key) + WHEN MATCHED THEN + UPDATE SET target.value = source.value + WHEN NOT MATCHED THEN + INSERT (key, value) + VALUES (source.key, source.value) """ with self._config.provide_connection() as conn: cursor = conn.cursor() - cursor.execute( - sql, - { - "session_id": event_record["session_id"], - "invocation_id": event_record["invocation_id"], - "author": event_record["author"], - "timestamp": event_record["timestamp"], - "event_data": self._serialize_event_data(event_record["event_data"]), - }, - ) + cursor.execute(sql, {"key": key, "value": value}) conn.commit() - async def append_event(self, event_record: EventRecord) -> None: - """Append an event to a session.""" - await async_(self._append_event)(event_record) - class OracleAsyncADKMemoryStore(BaseAsyncADKMemoryStore["OracleAsyncConfig"]): """Oracle ADK memory store using async oracledb driver.""" @@ -1658,6 +2140,98 @@ def __init__(self, config: "OracleAsyncConfig") -> None: adk_config = config.extension_config.get("adk", {}) self._in_memory: bool = bool(adk_config.get("in_memory", False)) + async def create_tables(self) -> None: + if not self._enabled: + return + + async with self._config.provide_session() as driver: + await driver.execute_script(await self._get_create_memory_table_sql()) + + async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + if not self._enabled: + msg = "Memory store is disabled" + raise RuntimeError(msg) + + if not entries: + return 0 + + owner_column = f", {self._owner_id_column_name}" if self._owner_id_column_name else "" + owner_param = ", :owner_id" if self._owner_id_column_name else "" + sql = f""" + INSERT INTO {self._memory_table} ( + id, session_id, app_name, user_id, event_id, author{owner_column}, + timestamp, content_json, content_text, metadata_json, inserted_at + ) VALUES ( + :id, :session_id, :app_name, :user_id, :event_id, :author{owner_param}, + :timestamp, :content_json, :content_text, :metadata_json, :inserted_at + ) + """ + + inserted_count = 0 + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + for entry in entries: + content_json = await self._serialize_json_field(entry["content_json"]) + metadata_json = await self._serialize_json_field(entry["metadata_json"]) + params = { + "id": entry["id"], + "session_id": entry["session_id"], + "app_name": entry["app_name"], + "user_id": entry["user_id"], + "event_id": entry["event_id"], + "author": entry["author"], + "timestamp": entry["timestamp"], + "content_json": content_json, + "content_text": entry["content_text"], + "metadata_json": metadata_json, + "inserted_at": entry["inserted_at"], + } + if self._owner_id_column_name: + params["owner_id"] = str(owner_id) if owner_id is not None else None + if await self._execute_insert_entry(cursor, sql, params): + inserted_count += 1 + await conn.commit() + + return inserted_count + + async def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + if not self._enabled: + msg = "Memory store is disabled" + raise RuntimeError(msg) + + effective_limit = limit if limit is not None else self._max_results + + try: + if self._use_fts: + return await self._search_entries_fts(query, app_name, user_id, effective_limit) + return await self._search_entries_simple(query, app_name, user_id, effective_limit) + except oracledb.DatabaseError as exc: + error_obj = exc.args[0] if exc.args else None + if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: + return [] + raise + + async def delete_entries_by_session(self, session_id: str) -> int: + sql = f"DELETE FROM {self._memory_table} WHERE session_id = :session_id" + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute(sql, {"session_id": session_id}) + await conn.commit() + return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + + async def delete_entries_older_than(self, days: int) -> int: + sql = f""" + DELETE FROM {self._memory_table} + WHERE inserted_at < SYSTIMESTAMP - NUMTODSINTERVAL(:days, 'DAY') + """ + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute(sql, {"days": days}) + await conn.commit() + return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + async def _detect_json_storage_type(self) -> "JSONStorageType": if self._json_storage_type is not None: return self._json_storage_type @@ -1818,13 +2392,6 @@ def _get_drop_memory_table_sql(self) -> "list[str]": """, ] - async def create_tables(self) -> None: - if not self._enabled: - return - - async with self._config.provide_session() as driver: - await driver.execute_script(await self._get_create_memory_table_sql()) - async def _execute_insert_entry(self, cursor: Any, sql: str, params: "dict[str, Any]") -> bool: """Execute an insert and skip duplicate key errors.""" try: @@ -1836,72 +2403,6 @@ async def _execute_insert_entry(self, cursor: Any, sql: str, params: "dict[str, raise return True - async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: - if not self._enabled: - msg = "Memory store is disabled" - raise RuntimeError(msg) - - if not entries: - return 0 - - owner_column = f", {self._owner_id_column_name}" if self._owner_id_column_name else "" - owner_param = ", :owner_id" if self._owner_id_column_name else "" - sql = f""" - INSERT INTO {self._memory_table} ( - id, session_id, app_name, user_id, event_id, author{owner_column}, - timestamp, content_json, content_text, metadata_json, inserted_at - ) VALUES ( - :id, :session_id, :app_name, :user_id, :event_id, :author{owner_param}, - :timestamp, :content_json, :content_text, :metadata_json, :inserted_at - ) - """ - - inserted_count = 0 - async with self._config.provide_connection() as conn: - cursor = conn.cursor() - for entry in entries: - content_json = await self._serialize_json_field(entry["content_json"]) - metadata_json = await self._serialize_json_field(entry["metadata_json"]) - params = { - "id": entry["id"], - "session_id": entry["session_id"], - "app_name": entry["app_name"], - "user_id": entry["user_id"], - "event_id": entry["event_id"], - "author": entry["author"], - "timestamp": entry["timestamp"], - "content_json": content_json, - "content_text": entry["content_text"], - "metadata_json": metadata_json, - "inserted_at": entry["inserted_at"], - } - if self._owner_id_column_name: - params["owner_id"] = str(owner_id) if owner_id is not None else None - if await self._execute_insert_entry(cursor, sql, params): - inserted_count += 1 - await conn.commit() - - return inserted_count - - async def search_entries( - self, query: str, app_name: str, user_id: str, limit: "int | None" = None - ) -> "list[MemoryRecord]": - if not self._enabled: - msg = "Memory store is disabled" - raise RuntimeError(msg) - - effective_limit = limit if limit is not None else self._max_results - - try: - if self._use_fts: - return await self._search_entries_fts(query, app_name, user_id, effective_limit) - return await self._search_entries_simple(query, app_name, user_id, effective_limit) - except oracledb.DatabaseError as exc: - error_obj = exc.args[0] if exc.args else None - if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: - return [] - raise - async def _search_entries_fts(self, query: str, app_name: str, user_id: str, limit: int) -> "list[MemoryRecord]": sql = f""" SELECT id, session_id, app_name, user_id, event_id, author, @@ -1948,25 +2449,6 @@ async def _search_entries_simple(self, query: str, app_name: str, user_id: str, rows = await cursor.fetchall() return await self._rows_to_records(rows) - async def delete_entries_by_session(self, session_id: str) -> int: - sql = f"DELETE FROM {self._memory_table} WHERE session_id = :session_id" - async with self._config.provide_connection() as conn: - cursor = conn.cursor() - await cursor.execute(sql, {"session_id": session_id}) - await conn.commit() - return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 - - async def delete_entries_older_than(self, days: int) -> int: - sql = f""" - DELETE FROM {self._memory_table} - WHERE inserted_at < SYSTIMESTAMP - NUMTODSINTERVAL(:days, 'DAY') - """ - async with self._config.provide_connection() as conn: - cursor = conn.cursor() - await cursor.execute(sql, {"days": days}) - await conn.commit() - return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 - async def _rows_to_records(self, rows: "list[Any]") -> "list[MemoryRecord]": records: list[MemoryRecord] = [] for row in rows: @@ -2003,6 +2485,28 @@ def __init__(self, config: "OracleSyncConfig") -> None: adk_config = config.extension_config.get("adk", {}) self._in_memory = bool(adk_config.get("in_memory", False)) + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + + async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication.""" + return await async_(self._insert_memory_entries)(entries, owner_id) + + async def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query.""" + return await async_(self._search_entries)(query, app_name, user_id, limit) + + async def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return await async_(self._delete_entries_by_session)(session_id) + + async def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + return await async_(self._delete_entries_older_than)(days) + def _detect_json_storage_type(self) -> "JSONStorageType": if self._json_storage_type is not None: return self._json_storage_type @@ -2170,10 +2674,6 @@ def _create_tables(self) -> None: with self._config.provide_session() as driver: driver.execute_script(run_(self._get_create_memory_table_sql)()) - async def create_tables(self) -> None: - """Create tables if they don't exist.""" - await async_(self._create_tables)() - def _execute_insert_entry(self, cursor: Any, sql: str, params: "dict[str, Any]") -> bool: """Execute an insert and skip duplicate key errors.""" try: @@ -2232,10 +2732,6 @@ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "objec return inserted_count - async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: - """Bulk insert memory entries with deduplication.""" - return await async_(self._insert_memory_entries)(entries, owner_id) - def _search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": @@ -2255,12 +2751,6 @@ def _search_entries( return [] raise - async def search_entries( - self, query: str, app_name: str, user_id: str, limit: "int | None" = None - ) -> "list[MemoryRecord]": - """Search memory entries by text query.""" - return await async_(self._search_entries)(query, app_name, user_id, limit) - def _search_entries_fts(self, query: str, app_name: str, user_id: str, limit: int) -> "list[MemoryRecord]": sql = f""" SELECT id, session_id, app_name, user_id, event_id, author, @@ -2315,10 +2805,6 @@ def _delete_entries_by_session(self, session_id: str) -> int: conn.commit() return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 - async def delete_entries_by_session(self, session_id: str) -> int: - """Delete all memory entries for a specific session.""" - return await async_(self._delete_entries_by_session)(session_id) - def _delete_entries_older_than(self, days: int) -> int: sql = f""" DELETE FROM {self._memory_table} @@ -2330,10 +2816,6 @@ def _delete_entries_older_than(self, days: int) -> int: conn.commit() return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 - async def delete_entries_older_than(self, days: int) -> int: - """Delete memory entries older than specified days.""" - return await async_(self._delete_entries_older_than)(days) - def _rows_to_records(self, rows: "list[Any]") -> "list[MemoryRecord]": records: list[MemoryRecord] = [] for row in rows: @@ -2422,6 +2904,15 @@ def _event_data_column_ddl(storage_type: JSONStorageType) -> str: return "event_data BLOB NOT NULL" +def _json_column_ddl(column_name: str, storage_type: JSONStorageType) -> str: + """Return an Oracle JSON column DDL fragment for the configured storage type.""" + if storage_type == JSONStorageType.JSON_NATIVE: + return f"{column_name} JSON NOT NULL" + if storage_type == JSONStorageType.BLOB_JSON: + return f"{column_name} BLOB CHECK ({column_name} IS JSON) NOT NULL" + return f"{column_name} BLOB NOT NULL" + + def _get_oracle_adk_config(config: Any) -> dict[str, Any]: adk_config = config.extension_config.get("adk", {}) if isinstance(adk_config, dict): diff --git a/sqlspec/adapters/psqlpy/adk/store.py b/sqlspec/adapters/psqlpy/adk/store.py index 8cc83fcaa..e33b3dd86 100644 --- a/sqlspec/adapters/psqlpy/adk/store.py +++ b/sqlspec/adapters/psqlpy/adk/store.py @@ -22,6 +22,7 @@ logger = get_logger("sqlspec.adapters.psqlpy.adk.store") POSTGRES_TABLE_NOT_FOUND_SQLSTATE: Final = "42P01" +PSQLPY_STATUS_REGEX: Final[re.Pattern[str]] = re.compile(r"^([A-Z]+)(?:\s+(\d+))?\s+(\d+)$", re.IGNORECASE) class PsqlpyADKStore(BaseAsyncADKStore["PsqlpyConfig"]): @@ -50,54 +51,14 @@ class PsqlpyADKStore(BaseAsyncADKStore["PsqlpyConfig"]): def __init__(self, config: "PsqlpyConfig") -> None: super().__init__(config) - async def _get_create_sessions_table_sql(self) -> str: - owner_id_line = "" - if self._owner_id_column_ddl: - owner_id_line = f",\n {self._owner_id_column_ddl}" - - return f""" - CREATE TABLE IF NOT EXISTS {self._session_table} ( - id VARCHAR(128) PRIMARY KEY, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL{owner_id_line}, - state JSONB NOT NULL DEFAULT '{{}}'::jsonb, - create_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP - ) WITH (fillfactor = 80); - - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_app_user - ON {self._session_table}(app_name, user_id); - - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time - ON {self._session_table}(update_time DESC); - - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_state - ON {self._session_table} USING GIN (state) - WHERE state != '{{}}'::jsonb; - """ - - async def _get_create_events_table_sql(self) -> str: - return f""" - CREATE TABLE IF NOT EXISTS {self._events_table} ( - session_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(256) NOT NULL, - timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - event_data JSONB NOT NULL, - FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE - ) WITH (fillfactor = 80); - - CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session - ON {self._events_table}(session_id, timestamp ASC); - """ - - def _get_drop_tables_sql(self) -> "list[str]": - return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] - async def create_tables(self) -> None: async with self._config.provide_session() as driver: await driver.execute_script(await self._get_create_sessions_table_sql()) await driver.execute_script(await self._get_create_events_table_sql()) + await driver.execute_script(await self._get_create_app_states_table_sql()) + await driver.execute_script(await self._get_create_user_states_table_sql()) + await driver.execute_script(await self._get_create_metadata_table_sql()) + await driver.execute_script(await self._get_seed_metadata_sql()) async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None @@ -354,58 +315,185 @@ async def delete_idle_sessions(self, updated_before: "datetime") -> int: return 0 raise + async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + sql = f"SELECT state FROM {self._app_state_table} WHERE app_name = $1" -PSQLPY_STATUS_REGEX: Final[re.Pattern[str]] = re.compile(r"^([A-Z]+)(?:\s+(\d+))?\s+(\d+)$", re.IGNORECASE) + try: + async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] + result = await conn.fetch(sql, [app_name]) + rows: list[dict[str, Any]] = result.result() if result else [] + return rows[0]["state"] if rows else None + except psqlpy.exceptions.DatabaseError as e: + error_msg = str(e).lower() + if "does not exist" in error_msg or "relation" in error_msg: + return None + raise + async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + sql = f"SELECT state FROM {self._user_state_table} WHERE app_name = $1 AND user_id = $2" -class PsqlpyADKMemoryStore(BaseAsyncADKMemoryStore["PsqlpyConfig"]): - """PostgreSQL ADK memory store using Psqlpy driver.""" + try: + async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] + result = await conn.fetch(sql, [app_name, user_id]) + rows: list[dict[str, Any]] = result.result() if result else [] + return rows[0]["state"] if rows else None + except psqlpy.exceptions.DatabaseError as e: + error_msg = str(e).lower() + if "does not exist" in error_msg or "relation" in error_msg: + return None + raise - __slots__ = () + async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES ($1, $2, CURRENT_TIMESTAMP) + ON CONFLICT (app_name) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """ - def __init__(self, config: "PsqlpyConfig") -> None: - """Initialize Psqlpy memory store.""" - super().__init__(config) + async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] + await conn.execute(sql, [app_name, state]) - async def _get_create_memory_table_sql(self) -> str: - """Get PostgreSQL CREATE TABLE SQL for memory entries.""" + async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES ($1, $2, $3, CURRENT_TIMESTAMP) + ON CONFLICT (app_name, user_id) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """ + + async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] + await conn.execute(sql, [app_name, user_id, state]) + + async def get_metadata(self, key: str) -> "str | None": + sql = f"SELECT value FROM {self._metadata_table} WHERE key = $1" + + try: + async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] + result = await conn.fetch(sql, [key]) + rows: list[dict[str, Any]] = result.result() if result else [] + return rows[0]["value"] if rows else None + except psqlpy.exceptions.DatabaseError as e: + error_msg = str(e).lower() + if "does not exist" in error_msg or "relation" in error_msg: + return None + raise + + async def set_metadata(self, key: str, value: str) -> None: + sql = f""" + INSERT INTO {self._metadata_table} (key, value) + VALUES ($1, $2) + ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value + """ + + async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] + await conn.execute(sql, [key, value]) + + async def _get_create_sessions_table_sql(self) -> str: owner_id_line = "" if self._owner_id_column_ddl: owner_id_line = f",\n {self._owner_id_column_ddl}" - fts_index = "" - if self._use_fts: - fts_index = f""" - CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_fts - ON {self._memory_table} USING GIN (to_tsvector('english', content_text)); - """ - return f""" - CREATE TABLE IF NOT EXISTS {self._memory_table} ( + CREATE TABLE IF NOT EXISTS {self._session_table} ( id VARCHAR(128) PRIMARY KEY, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL{owner_id_line}, + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + create_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ) WITH (fillfactor = 80); + + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_app_user + ON {self._session_table}(app_name, user_id); + + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time + ON {self._session_table}(update_time DESC); + + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_state + ON {self._session_table} USING GIN (state) + WHERE state != '{{}}'::jsonb; + """ + + async def _get_create_events_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._events_table} ( session_id VARCHAR(128) NOT NULL, + invocation_id VARCHAR(256) NOT NULL, + author VARCHAR(256) NOT NULL, + timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + event_data JSONB NOT NULL, + FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE + ) WITH (fillfactor = 80); + + CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session + ON {self._events_table}(session_id, timestamp ASC); + """ + + async def _get_create_app_states_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._app_state_table} ( + app_name VARCHAR(128) PRIMARY KEY, + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ) WITH (fillfactor = 80); + """ + + async def _get_create_user_states_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._user_state_table} ( app_name VARCHAR(128) NOT NULL, user_id VARCHAR(128) NOT NULL, - event_id VARCHAR(128) NOT NULL UNIQUE, - author VARCHAR(256){owner_id_line}, - timestamp TIMESTAMPTZ NOT NULL, - content_json JSONB NOT NULL, - content_text TEXT NOT NULL, - metadata_json JSONB, - inserted_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP - ); + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (app_name, user_id) + ) WITH (fillfactor = 80); + """ - CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_app_user_time - ON {self._memory_table}(app_name, user_id, timestamp DESC); + async def _get_create_metadata_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._metadata_table} ( + key VARCHAR(128) PRIMARY KEY, + value VARCHAR(512) NOT NULL + ); + """ - CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_session - ON {self._memory_table}(session_id); - {fts_index} + async def _get_seed_metadata_sql(self) -> str: + return f""" + INSERT INTO {self._metadata_table} (key, value) + VALUES ('schema_version', '1') + ON CONFLICT (key) DO NOTHING """ - def _get_drop_memory_table_sql(self) -> "list[str]": - """Get PostgreSQL DROP TABLE SQL statements.""" - return [f"DROP TABLE IF EXISTS {self._memory_table}"] + def _get_drop_app_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._app_state_table}" + + def _get_drop_user_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._user_state_table}" + + def _get_drop_metadata_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._metadata_table}" + + def _get_drop_tables_sql(self) -> "list[str]": + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE IF EXISTS {self._events_table}", + f"DROP TABLE IF EXISTS {self._session_table}", + ] + + +class PsqlpyADKMemoryStore(BaseAsyncADKMemoryStore["PsqlpyConfig"]): + """PostgreSQL ADK memory store using Psqlpy driver.""" + + __slots__ = () + + def __init__(self, config: "PsqlpyConfig") -> None: + """Initialize Psqlpy memory store.""" + super().__init__(config) async def create_tables(self) -> None: """Create the memory table and indexes if they don't exist.""" @@ -509,42 +597,6 @@ async def search_entries( return [] raise - async def _search_entries_fts(self, query: str, app_name: str, user_id: str, limit: int) -> "list[MemoryRecord]": - sql = f""" - SELECT id, session_id, app_name, user_id, event_id, author, - timestamp, content_json, content_text, metadata_json, inserted_at, - ts_rank(to_tsvector('english', content_text), plainto_tsquery('english', $1)) as rank - FROM {self._memory_table} - WHERE app_name = $2 - AND user_id = $3 - AND to_tsvector('english', content_text) @@ plainto_tsquery('english', $1) - ORDER BY rank DESC, timestamp DESC - LIMIT $4 - """ - params = [query, app_name, user_id, limit] - async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] - result = await conn.fetch(sql, params) - rows: list[dict[str, Any]] = result.result() if result else [] - return _rows_to_records(rows) - - async def _search_entries_simple(self, query: str, app_name: str, user_id: str, limit: int) -> "list[MemoryRecord]": - sql = f""" - SELECT id, session_id, app_name, user_id, event_id, author, - timestamp, content_json, content_text, metadata_json, inserted_at - FROM {self._memory_table} - WHERE app_name = $1 - AND user_id = $2 - AND content_text ILIKE $3 - ORDER BY timestamp DESC - LIMIT $4 - """ - pattern = f"%{query}%" - params = [app_name, user_id, pattern, limit] - async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] - result = await conn.fetch(sql, params) - rows: list[dict[str, Any]] = result.result() if result else [] - return _rows_to_records(rows) - async def delete_entries_by_session(self, session_id: str) -> int: """Delete all memory entries for a specific session.""" count_sql = f"SELECT COUNT(*) AS count FROM {self._memory_table} WHERE session_id = $1" @@ -587,6 +639,82 @@ async def delete_entries_older_than(self, days: int) -> int: return 0 raise + async def _get_create_memory_table_sql(self) -> str: + """Get PostgreSQL CREATE TABLE SQL for memory entries.""" + owner_id_line = "" + if self._owner_id_column_ddl: + owner_id_line = f",\n {self._owner_id_column_ddl}" + + fts_index = "" + if self._use_fts: + fts_index = f""" + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_fts + ON {self._memory_table} USING GIN (to_tsvector('english', content_text)); + """ + + return f""" + CREATE TABLE IF NOT EXISTS {self._memory_table} ( + id VARCHAR(128) PRIMARY KEY, + session_id VARCHAR(128) NOT NULL, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + event_id VARCHAR(128) NOT NULL UNIQUE, + author VARCHAR(256){owner_id_line}, + timestamp TIMESTAMPTZ NOT NULL, + content_json JSONB NOT NULL, + content_text TEXT NOT NULL, + metadata_json JSONB, + inserted_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_app_user_time + ON {self._memory_table}(app_name, user_id, timestamp DESC); + + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_session + ON {self._memory_table}(session_id); + {fts_index} + """ + + def _get_drop_memory_table_sql(self) -> "list[str]": + """Get PostgreSQL DROP TABLE SQL statements.""" + return [f"DROP TABLE IF EXISTS {self._memory_table}"] + + async def _search_entries_fts(self, query: str, app_name: str, user_id: str, limit: int) -> "list[MemoryRecord]": + sql = f""" + SELECT id, session_id, app_name, user_id, event_id, author, + timestamp, content_json, content_text, metadata_json, inserted_at, + ts_rank(to_tsvector('english', content_text), plainto_tsquery('english', $1)) as rank + FROM {self._memory_table} + WHERE app_name = $2 + AND user_id = $3 + AND to_tsvector('english', content_text) @@ plainto_tsquery('english', $1) + ORDER BY rank DESC, timestamp DESC + LIMIT $4 + """ + params = [query, app_name, user_id, limit] + async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] + result = await conn.fetch(sql, params) + rows: list[dict[str, Any]] = result.result() if result else [] + return _rows_to_records(rows) + + async def _search_entries_simple(self, query: str, app_name: str, user_id: str, limit: int) -> "list[MemoryRecord]": + sql = f""" + SELECT id, session_id, app_name, user_id, event_id, author, + timestamp, content_json, content_text, metadata_json, inserted_at + FROM {self._memory_table} + WHERE app_name = $1 + AND user_id = $2 + AND content_text ILIKE $3 + ORDER BY timestamp DESC + LIMIT $4 + """ + pattern = f"%{query}%" + params = [app_name, user_id, pattern, limit] + async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] + result = await conn.fetch(sql, params) + rows: list[dict[str, Any]] = result.result() if result else [] + return _rows_to_records(rows) + def _extract_rows_affected(self, result: Any) -> int: """Extract rows affected from psqlpy result.""" try: diff --git a/sqlspec/adapters/psycopg/adk/store.py b/sqlspec/adapters/psycopg/adk/store.py index 0b4025d74..ee6c414a1 100644 --- a/sqlspec/adapters/psycopg/adk/store.py +++ b/sqlspec/adapters/psycopg/adk/store.py @@ -82,54 +82,14 @@ class PsycopgAsyncADKStore(BaseAsyncADKStore["PsycopgAsyncConfig"]): def __init__(self, config: "PsycopgAsyncConfig") -> None: super().__init__(config) - async def _get_create_sessions_table_sql(self) -> str: - owner_id_line = "" - if self._owner_id_column_ddl: - owner_id_line = f",\n {self._owner_id_column_ddl}" - - return f""" - CREATE TABLE IF NOT EXISTS {self._session_table} ( - id VARCHAR(128) PRIMARY KEY, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL{owner_id_line}, - state JSONB NOT NULL DEFAULT '{{}}'::jsonb, - create_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP - ) WITH (fillfactor = 80); - - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_app_user - ON {self._session_table}(app_name, user_id); - - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time - ON {self._session_table}(update_time DESC); - - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_state - ON {self._session_table} USING GIN (state) - WHERE state != '{{}}'::jsonb; - """ - - async def _get_create_events_table_sql(self) -> str: - return f""" - CREATE TABLE IF NOT EXISTS {self._events_table} ( - session_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(256) NOT NULL, - timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - event_data JSONB NOT NULL, - FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE - ) WITH (fillfactor = 80); - - CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session - ON {self._events_table}(session_id, timestamp ASC); - """ - - def _get_drop_tables_sql(self) -> "list[str]": - return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] - async def create_tables(self) -> None: async with self._config.provide_session() as driver: await driver.execute_script(await self._get_create_sessions_table_sql()) await driver.execute_script(await self._get_create_events_table_sql()) + await driver.execute_script(await self._get_create_app_states_table_sql()) + await driver.execute_script(await self._get_create_user_states_table_sql()) + await driver.execute_script(await self._get_create_metadata_table_sql()) + await driver.execute_script(await self._get_seed_metadata_sql()) async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None @@ -384,6 +344,173 @@ async def delete_idle_sessions(self, updated_before: "datetime") -> int: except errors.UndefinedTable: return 0 + async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + query = pg_sql.SQL("SELECT state FROM {table} WHERE app_name = %s").format( + table=pg_sql.Identifier(self._app_state_table) + ) + + try: + async with self._config.provide_connection() as conn, conn.cursor() as cur: + await cur.execute(query, (app_name,)) + row = await cur.fetchone() + return row["state"] if row is not None else None + except errors.UndefinedTable: + return None + + async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + query = pg_sql.SQL("SELECT state FROM {table} WHERE app_name = %s AND user_id = %s").format( + table=pg_sql.Identifier(self._user_state_table) + ) + + try: + async with self._config.provide_connection() as conn, conn.cursor() as cur: + await cur.execute(query, (app_name, user_id)) + row = await cur.fetchone() + return row["state"] if row is not None else None + except errors.UndefinedTable: + return None + + async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + query = pg_sql.SQL(""" + INSERT INTO {table} (app_name, state, update_time) + VALUES (%s, %s, CURRENT_TIMESTAMP) + ON CONFLICT (app_name) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """).format(table=pg_sql.Identifier(self._app_state_table)) + + async with self._config.provide_connection() as conn, conn.cursor() as cur: + await cur.execute(query, (app_name, Jsonb(state))) + + async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + query = pg_sql.SQL(""" + INSERT INTO {table} (app_name, user_id, state, update_time) + VALUES (%s, %s, %s, CURRENT_TIMESTAMP) + ON CONFLICT (app_name, user_id) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """).format(table=pg_sql.Identifier(self._user_state_table)) + + async with self._config.provide_connection() as conn, conn.cursor() as cur: + await cur.execute(query, (app_name, user_id, Jsonb(state))) + + async def get_metadata(self, key: str) -> "str | None": + query = pg_sql.SQL("SELECT value FROM {table} WHERE key = %s").format( + table=pg_sql.Identifier(self._metadata_table) + ) + + try: + async with self._config.provide_connection() as conn, conn.cursor() as cur: + await cur.execute(query, (key,)) + row = await cur.fetchone() + return row["value"] if row is not None else None + except errors.UndefinedTable: + return None + + async def set_metadata(self, key: str, value: str) -> None: + query = pg_sql.SQL(""" + INSERT INTO {table} (key, value) + VALUES (%s, %s) + ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value + """).format(table=pg_sql.Identifier(self._metadata_table)) + + async with self._config.provide_connection() as conn, conn.cursor() as cur: + await cur.execute(query, (key, value)) + + async def _get_create_sessions_table_sql(self) -> str: + owner_id_line = "" + if self._owner_id_column_ddl: + owner_id_line = f",\n {self._owner_id_column_ddl}" + + return f""" + CREATE TABLE IF NOT EXISTS {self._session_table} ( + id VARCHAR(128) PRIMARY KEY, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL{owner_id_line}, + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + create_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ) WITH (fillfactor = 80); + + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_app_user + ON {self._session_table}(app_name, user_id); + + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time + ON {self._session_table}(update_time DESC); + + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_state + ON {self._session_table} USING GIN (state) + WHERE state != '{{}}'::jsonb; + """ + + async def _get_create_events_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._events_table} ( + session_id VARCHAR(128) NOT NULL, + invocation_id VARCHAR(256) NOT NULL, + author VARCHAR(256) NOT NULL, + timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + event_data JSONB NOT NULL, + FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE + ) WITH (fillfactor = 80); + + CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session + ON {self._events_table}(session_id, timestamp ASC); + """ + + async def _get_create_app_states_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._app_state_table} ( + app_name VARCHAR(128) PRIMARY KEY, + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ) WITH (fillfactor = 80); + """ + + async def _get_create_user_states_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._user_state_table} ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (app_name, user_id) + ) WITH (fillfactor = 80); + """ + + async def _get_create_metadata_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._metadata_table} ( + key VARCHAR(128) PRIMARY KEY, + value VARCHAR(512) NOT NULL + ); + """ + + async def _get_seed_metadata_sql(self) -> str: + return f""" + INSERT INTO {self._metadata_table} (key, value) + VALUES ('schema_version', '1') + ON CONFLICT (key) DO NOTHING + """ + + def _get_drop_app_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._app_state_table}" + + def _get_drop_user_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._user_state_table}" + + def _get_drop_metadata_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._metadata_table}" + + def _get_drop_tables_sql(self) -> "list[str]": + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE IF EXISTS {self._events_table}", + f"DROP TABLE IF EXISTS {self._session_table}", + ] + class PsycopgSyncADKStore(BaseAsyncADKStore["PsycopgSyncConfig"]): """PostgreSQL synchronous ADK store using Psycopg3 driver. @@ -411,6 +538,82 @@ class PsycopgSyncADKStore(BaseAsyncADKStore["PsycopgSyncConfig"]): def __init__(self, config: "PsycopgSyncConfig") -> None: super().__init__(config) + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + + async def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new session.""" + return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) + + async def get_session( + self, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + """Get session by ID.""" + return await async_(self._get_session)(session_id, renew_for) + + async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + """Update session state.""" + await async_(self._update_session_state)(session_id, state) + + async def delete_session(self, session_id: str) -> None: + """Delete session and associated events.""" + await async_(self._delete_session)(session_id) + + async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app.""" + return await async_(self._list_sessions)(app_name, user_id) + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> SessionRecord: + """Atomically append an event and update the session's durable state.""" + return await async_(self._append_event_and_update_state)(event_record, session_id, state) + + async def get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + """Get events for a session.""" + return await async_(self._get_events)(session_id, after_timestamp, limit) + + async def delete_expired_events(self, before: "datetime") -> int: + """Delete events older than the given timestamp.""" + return await async_(self._delete_expired_events)(before) + + async def delete_idle_sessions(self, updated_before: "datetime") -> int: + """Delete sessions whose update_time predates the given threshold.""" + return await async_(self._delete_idle_sessions)(updated_before) + + async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Return app-scoped state for an application.""" + return await async_(self._get_app_state)(app_name) + + async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Return user-scoped state for an application user.""" + return await async_(self._get_user_state)(app_name, user_id) + + async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Insert or replace app-scoped state for an application.""" + await async_(self._upsert_app_state)(app_name, state) + + async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Insert or replace user-scoped state for an application user.""" + await async_(self._upsert_user_state)(app_name, user_id, state) + + async def get_metadata(self, key: str) -> "str | None": + """Return a value from the ADK internal metadata table.""" + return await async_(self._get_metadata)(key) + + async def set_metadata(self, key: str, value: str) -> None: + """Set a value in the ADK internal metadata table.""" + await async_(self._set_metadata)(key, value) + + async def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session.""" + await async_(self._append_event)(event_record) + async def _get_create_sessions_table_sql(self) -> str: owner_id_line = "" if self._owner_id_column_ddl: @@ -452,17 +655,67 @@ async def _get_create_events_table_sql(self) -> str: ON {self._events_table}(session_id, timestamp ASC); """ + async def _get_create_app_states_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._app_state_table} ( + app_name VARCHAR(128) PRIMARY KEY, + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ) WITH (fillfactor = 80); + """ + + async def _get_create_user_states_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._user_state_table} ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (app_name, user_id) + ) WITH (fillfactor = 80); + """ + + async def _get_create_metadata_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._metadata_table} ( + key VARCHAR(128) PRIMARY KEY, + value VARCHAR(512) NOT NULL + ); + """ + + async def _get_seed_metadata_sql(self) -> str: + return f""" + INSERT INTO {self._metadata_table} (key, value) + VALUES ('schema_version', '1') + ON CONFLICT (key) DO NOTHING + """ + + def _get_drop_app_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._app_state_table}" + + def _get_drop_user_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._user_state_table}" + + def _get_drop_metadata_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._metadata_table}" + def _get_drop_tables_sql(self) -> "list[str]": - return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE IF EXISTS {self._events_table}", + f"DROP TABLE IF EXISTS {self._session_table}", + ] def _create_tables(self) -> None: with self._config.provide_session() as driver: driver.execute_script(run_(self._get_create_sessions_table_sql)()) driver.execute_script(run_(self._get_create_events_table_sql)()) - - async def create_tables(self) -> None: - """Create tables if they don't exist.""" - await async_(self._create_tables)() + driver.execute_script(run_(self._get_create_app_states_table_sql)()) + driver.execute_script(run_(self._get_create_user_states_table_sql)()) + driver.execute_script(run_(self._get_create_metadata_table_sql)()) + driver.execute_script(run_(self._get_seed_metadata_sql)()) def _create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None @@ -492,12 +745,6 @@ def _create_session( raise RuntimeError(msg) return result - async def create_session( - self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None - ) -> SessionRecord: - """Create a new session.""" - return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) - def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = None) -> "SessionRecord | None": if renew_for is not None and self._calculate_expires_at(renew_for) is not None: query = pg_sql.SQL(""" @@ -532,12 +779,6 @@ def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = No except errors.UndefinedTable: return None - async def get_session( - self, session_id: str, *, renew_for: "int | timedelta | None" = None - ) -> "SessionRecord | None": - """Get session by ID.""" - return await async_(self._get_session)(session_id, renew_for) - def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: query = pg_sql.SQL(""" UPDATE {table} @@ -548,20 +789,12 @@ def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> Non with self._config.provide_connection() as conn, conn.cursor() as cur: cur.execute(query, (Jsonb(state), session_id)) - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state.""" - await async_(self._update_session_state)(session_id, state) - def _delete_session(self, session_id: str) -> None: query = pg_sql.SQL("DELETE FROM {table} WHERE id = %s").format(table=pg_sql.Identifier(self._session_table)) with self._config.provide_connection() as conn, conn.cursor() as cur: cur.execute(query, (session_id,)) - async def delete_session(self, session_id: str) -> None: - """Delete session and associated events.""" - await async_(self._delete_session)(session_id) - def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": if user_id is None: query = pg_sql.SQL(""" @@ -599,10 +832,6 @@ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Ses except errors.UndefinedTable: return [] - async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app.""" - return await async_(self._list_sessions)(app_name, user_id) - def _insert_event(self, event_record: EventRecord) -> None: insert_query = pg_sql.SQL(""" INSERT INTO {table} ( @@ -673,12 +902,6 @@ def _append_event_and_update_state( update_time=row["update_time"], ) - async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" - ) -> SessionRecord: - """Atomically append an event and update the session's durable state.""" - return await async_(self._append_event_and_update_state)(event_record, session_id, state) - def _get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None ) -> "list[EventRecord]": @@ -724,12 +947,6 @@ def _get_events( except errors.UndefinedTable: return [] - async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None - ) -> "list[EventRecord]": - """Get events for a session.""" - return await async_(self._get_events)(session_id, after_timestamp, limit) - def _delete_expired_events(self, before: "datetime") -> int: query = pg_sql.SQL("DELETE FROM {table} WHERE timestamp < %s").format( table=pg_sql.Identifier(self._events_table) @@ -743,10 +960,6 @@ def _delete_expired_events(self, before: "datetime") -> int: except errors.UndefinedTable: return 0 - async def delete_expired_events(self, before: "datetime") -> int: - """Delete events older than the given timestamp.""" - return await async_(self._delete_expired_events)(before) - def _delete_idle_sessions(self, updated_before: "datetime") -> int: query = pg_sql.SQL("DELETE FROM {table} WHERE update_time < %s").format( table=pg_sql.Identifier(self._session_table) @@ -760,18 +973,86 @@ def _delete_idle_sessions(self, updated_before: "datetime") -> int: except errors.UndefinedTable: return 0 - async def delete_idle_sessions(self, updated_before: "datetime") -> int: - """Delete sessions whose update_time predates the given threshold.""" - return await async_(self._delete_idle_sessions)(updated_before) + def _get_app_state(self, app_name: str) -> "dict[str, Any] | None": + query = pg_sql.SQL("SELECT state FROM {table} WHERE app_name = %s").format( + table=pg_sql.Identifier(self._app_state_table) + ) + + try: + with self._config.provide_connection() as conn, conn.cursor() as cur: + cur.execute(query, (app_name,)) + row = cur.fetchone() + return row["state"] if row is not None else None + except errors.UndefinedTable: + return None + + def _get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + query = pg_sql.SQL("SELECT state FROM {table} WHERE app_name = %s AND user_id = %s").format( + table=pg_sql.Identifier(self._user_state_table) + ) + + try: + with self._config.provide_connection() as conn, conn.cursor() as cur: + cur.execute(query, (app_name, user_id)) + row = cur.fetchone() + return row["state"] if row is not None else None + except errors.UndefinedTable: + return None + + def _upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + query = pg_sql.SQL(""" + INSERT INTO {table} (app_name, state, update_time) + VALUES (%s, %s, CURRENT_TIMESTAMP) + ON CONFLICT (app_name) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """).format(table=pg_sql.Identifier(self._app_state_table)) + + with self._config.provide_connection() as conn, conn.cursor() as cur: + cur.execute(query, (app_name, Jsonb(state))) + conn.commit() + + def _upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + query = pg_sql.SQL(""" + INSERT INTO {table} (app_name, user_id, state, update_time) + VALUES (%s, %s, %s, CURRENT_TIMESTAMP) + ON CONFLICT (app_name, user_id) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """).format(table=pg_sql.Identifier(self._user_state_table)) + + with self._config.provide_connection() as conn, conn.cursor() as cur: + cur.execute(query, (app_name, user_id, Jsonb(state))) + conn.commit() + + def _get_metadata(self, key: str) -> "str | None": + query = pg_sql.SQL("SELECT value FROM {table} WHERE key = %s").format( + table=pg_sql.Identifier(self._metadata_table) + ) + + try: + with self._config.provide_connection() as conn, conn.cursor() as cur: + cur.execute(query, (key,)) + row = cur.fetchone() + return row["value"] if row is not None else None + except errors.UndefinedTable: + return None + + def _set_metadata(self, key: str, value: str) -> None: + query = pg_sql.SQL(""" + INSERT INTO {table} (key, value) + VALUES (%s, %s) + ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value + """).format(table=pg_sql.Identifier(self._metadata_table)) + + with self._config.provide_connection() as conn, conn.cursor() as cur: + cur.execute(query, (key, value)) + conn.commit() def _append_event(self, event_record: EventRecord) -> None: """Synchronous implementation of append_event.""" self._insert_event(event_record) - async def append_event(self, event_record: EventRecord) -> None: - """Append an event to a session.""" - await async_(self._append_event)(event_record) - class PsycopgAsyncADKMemoryStore(BaseAsyncADKMemoryStore["PsycopgAsyncConfig"]): """PostgreSQL ADK memory store using Psycopg3 async driver.""" @@ -782,46 +1063,6 @@ def __init__(self, config: "PsycopgAsyncConfig") -> None: """Initialize Psycopg async memory store.""" super().__init__(config) - async def _get_create_memory_table_sql(self) -> str: - """Get PostgreSQL CREATE TABLE SQL for memory entries.""" - owner_id_line = "" - if self._owner_id_column_ddl: - owner_id_line = f",\n {self._owner_id_column_ddl}" - - fts_index = "" - if self._use_fts: - fts_index = f""" - CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_fts - ON {self._memory_table} USING GIN (to_tsvector('english', content_text)); - """ - - return f""" - CREATE TABLE IF NOT EXISTS {self._memory_table} ( - id VARCHAR(128) PRIMARY KEY, - session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - event_id VARCHAR(128) NOT NULL UNIQUE, - author VARCHAR(256){owner_id_line}, - timestamp TIMESTAMPTZ NOT NULL, - content_json JSONB NOT NULL, - content_text TEXT NOT NULL, - metadata_json JSONB, - inserted_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP - ); - - CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_app_user_time - ON {self._memory_table}(app_name, user_id, timestamp DESC); - - CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_session - ON {self._memory_table}(session_id); - {fts_index} - """ - - def _get_drop_memory_table_sql(self) -> "list[str]": - """Get PostgreSQL DROP TABLE SQL statements.""" - return [f"DROP TABLE IF EXISTS {self._memory_table}"] - async def create_tables(self) -> None: """Create the memory table and indexes if they don't exist.""" if not self._enabled: @@ -895,6 +1136,69 @@ async def search_entries( except errors.UndefinedTable: return [] + async def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + sql = pg_sql.SQL("DELETE FROM {table} WHERE session_id = %s").format( + table=pg_sql.Identifier(self._memory_table) + ) + + async with self._config.provide_connection() as conn, conn.cursor() as cur: + await cur.execute(sql, (session_id,)) + return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 + + async def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + sql = pg_sql.SQL( + """ + DELETE FROM {table} + WHERE inserted_at < CURRENT_TIMESTAMP - {interval}::interval + """ + ).format(table=pg_sql.Identifier(self._memory_table), interval=pg_sql.Literal(f"{days} days")) + + async with self._config.provide_connection() as conn, conn.cursor() as cur: + await cur.execute(sql) + return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 + + async def _get_create_memory_table_sql(self) -> str: + """Get PostgreSQL CREATE TABLE SQL for memory entries.""" + owner_id_line = "" + if self._owner_id_column_ddl: + owner_id_line = f",\n {self._owner_id_column_ddl}" + + fts_index = "" + if self._use_fts: + fts_index = f""" + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_fts + ON {self._memory_table} USING GIN (to_tsvector('english', content_text)); + """ + + return f""" + CREATE TABLE IF NOT EXISTS {self._memory_table} ( + id VARCHAR(128) PRIMARY KEY, + session_id VARCHAR(128) NOT NULL, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + event_id VARCHAR(128) NOT NULL UNIQUE, + author VARCHAR(256){owner_id_line}, + timestamp TIMESTAMPTZ NOT NULL, + content_json JSONB NOT NULL, + content_text TEXT NOT NULL, + metadata_json JSONB, + inserted_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_app_user_time + ON {self._memory_table}(app_name, user_id, timestamp DESC); + + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_session + ON {self._memory_table}(session_id); + {fts_index} + """ + + def _get_drop_memory_table_sql(self) -> "list[str]": + """Get PostgreSQL DROP TABLE SQL statements.""" + return [f"DROP TABLE IF EXISTS {self._memory_table}"] + async def _search_entries_fts(self, query: str, app_name: str, user_id: str, limit: int) -> "list[MemoryRecord]": sql = pg_sql.SQL( """ @@ -935,29 +1239,6 @@ async def _search_entries_simple(self, query: str, app_name: str, user_id: str, rows = await cur.fetchall() return _rows_to_records(rows) - async def delete_entries_by_session(self, session_id: str) -> int: - """Delete all memory entries for a specific session.""" - sql = pg_sql.SQL("DELETE FROM {table} WHERE session_id = %s").format( - table=pg_sql.Identifier(self._memory_table) - ) - - async with self._config.provide_connection() as conn, conn.cursor() as cur: - await cur.execute(sql, (session_id,)) - return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 - - async def delete_entries_older_than(self, days: int) -> int: - """Delete memory entries older than specified days.""" - sql = pg_sql.SQL( - """ - DELETE FROM {table} - WHERE inserted_at < CURRENT_TIMESTAMP - {interval}::interval - """ - ).format(table=pg_sql.Identifier(self._memory_table), interval=pg_sql.Literal(f"{days} days")) - - async with self._config.provide_connection() as conn, conn.cursor() as cur: - await cur.execute(sql) - return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 - class PsycopgSyncADKMemoryStore(BaseAsyncADKMemoryStore["PsycopgSyncConfig"]): """PostgreSQL ADK memory store using Psycopg3 sync driver.""" @@ -968,6 +1249,28 @@ def __init__(self, config: "PsycopgSyncConfig") -> None: """Initialize Psycopg sync memory store.""" super().__init__(config) + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + + async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication.""" + return await async_(self._insert_memory_entries)(entries, owner_id) + + async def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query.""" + return await async_(self._search_entries)(query, app_name, user_id, limit) + + async def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return await async_(self._delete_entries_by_session)(session_id) + + async def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + return await async_(self._delete_entries_older_than)(days) + async def _get_create_memory_table_sql(self) -> str: """Get PostgreSQL CREATE TABLE SQL for memory entries.""" owner_id_line = "" @@ -1016,10 +1319,6 @@ def _create_tables(self) -> None: with self._config.provide_session() as driver: driver.execute_script(run_(self._get_create_memory_table_sql)()) - async def create_tables(self) -> None: - """Create tables if they don't exist.""" - await async_(self._create_tables)() - def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: """Bulk insert memory entries with deduplication.""" if not self._enabled: @@ -1065,10 +1364,6 @@ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "objec return inserted_count - async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: - """Bulk insert memory entries with deduplication.""" - return await async_(self._insert_memory_entries)(entries, owner_id) - def _search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": @@ -1089,12 +1384,6 @@ def _search_entries( except errors.UndefinedTable: return [] - async def search_entries( - self, query: str, app_name: str, user_id: str, limit: "int | None" = None - ) -> "list[MemoryRecord]": - """Search memory entries by text query.""" - return await async_(self._search_entries)(query, app_name, user_id, limit) - def _search_entries_fts(self, query: str, app_name: str, user_id: str, limit: int) -> "list[MemoryRecord]": sql = pg_sql.SQL( """ @@ -1145,10 +1434,6 @@ def _delete_entries_by_session(self, session_id: str) -> int: cur.execute(sql, (session_id,)) return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 - async def delete_entries_by_session(self, session_id: str) -> int: - """Delete all memory entries for a specific session.""" - return await async_(self._delete_entries_by_session)(session_id) - def _delete_entries_older_than(self, days: int) -> int: """Delete memory entries older than specified days.""" sql = pg_sql.SQL( @@ -1162,10 +1447,6 @@ def _delete_entries_older_than(self, days: int) -> int: cur.execute(sql) return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 - async def delete_entries_older_than(self, days: int) -> int: - """Delete memory entries older than specified days.""" - return await async_(self._delete_entries_older_than)(days) - def _rows_to_records(rows: "list[Any]") -> "list[MemoryRecord]": return [ diff --git a/sqlspec/adapters/pymysql/adk/store.py b/sqlspec/adapters/pymysql/adk/store.py index e7fbb0340..c130aa59e 100644 --- a/sqlspec/adapters/pymysql/adk/store.py +++ b/sqlspec/adapters/pymysql/adk/store.py @@ -57,6 +57,82 @@ class PyMysqlADKStore(BaseAsyncADKStore["PyMysqlConfig"]): def __init__(self, config: "PyMysqlConfig") -> None: super().__init__(config) + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + + async def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new session.""" + return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) + + async def get_session( + self, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + """Get session by ID.""" + return await async_(self._get_session)(session_id, renew_for) + + async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + """Update session state.""" + await async_(self._update_session_state)(session_id, state) + + async def delete_session(self, session_id: str) -> None: + """Delete session and associated events.""" + await async_(self._delete_session)(session_id) + + async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app.""" + return await async_(self._list_sessions)(app_name, user_id) + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> SessionRecord: + """Atomically append an event and update the session's durable state.""" + return await async_(self._append_event_and_update_state)(event_record, session_id, state) + + async def get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + """Get events for a session.""" + return await async_(self._get_events)(session_id, after_timestamp, limit) + + async def delete_expired_events(self, before: "datetime") -> int: + """Delete events older than the given timestamp.""" + return await async_(self._delete_expired_events)(before) + + async def delete_idle_sessions(self, updated_before: "datetime") -> int: + """Delete sessions whose update_time predates the given threshold.""" + return await async_(self._delete_idle_sessions)(updated_before) + + async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Return app-scoped state for an application.""" + return await async_(self._get_app_state)(app_name) + + async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Return user-scoped state for an application user.""" + return await async_(self._get_user_state)(app_name, user_id) + + async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Insert or replace app-scoped state for an application.""" + await async_(self._upsert_app_state)(app_name, state) + + async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Insert or replace user-scoped state for an application user.""" + await async_(self._upsert_user_state)(app_name, user_id, state) + + async def get_metadata(self, key: str) -> "str | None": + """Return a value from the ADK internal metadata table.""" + return await async_(self._get_metadata)(key) + + async def set_metadata(self, key: str, value: str) -> None: + """Set a value in the ADK internal metadata table.""" + await async_(self._set_metadata)(key, value) + + async def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session.""" + await async_(self._append_event)(event_record) + def _parse_owner_id_column_for_mysql(self, column_ddl: str) -> "tuple[str, str]": return _parse_owner_id_column_for_mysql(column_ddl) @@ -101,17 +177,73 @@ async def _get_create_events_table_sql(self) -> str: ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci """ + async def _get_create_app_states_table_sql(self) -> str: + """Get MySQL CREATE TABLE SQL for app-scoped state.""" + return f""" + CREATE TABLE IF NOT EXISTS {self._app_state_table} ( + app_name VARCHAR(128) PRIMARY KEY, + state JSON NOT NULL, + update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + async def _get_create_user_states_table_sql(self) -> str: + """Get MySQL CREATE TABLE SQL for user-scoped state.""" + return f""" + CREATE TABLE IF NOT EXISTS {self._user_state_table} ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + state JSON NOT NULL, + update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6), + PRIMARY KEY (app_name, user_id) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + async def _get_create_metadata_table_sql(self) -> str: + """Get MySQL CREATE TABLE SQL for ADK internal metadata.""" + return f""" + CREATE TABLE IF NOT EXISTS {self._metadata_table} ( + `key` VARCHAR(128) PRIMARY KEY, + value VARCHAR(512) NOT NULL + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + async def _get_seed_metadata_sql(self) -> str: + """Get MySQL SQL that seeds the ADK schema version metadata row.""" + return f""" + INSERT IGNORE INTO {self._metadata_table} (`key`, value) + VALUES ('schema_version', '1') + """ + + def _get_drop_app_states_table_sql(self) -> str: + """Get MySQL DROP TABLE SQL for app-scoped state.""" + return f"DROP TABLE IF EXISTS {self._app_state_table}" + + def _get_drop_user_states_table_sql(self) -> str: + """Get MySQL DROP TABLE SQL for user-scoped state.""" + return f"DROP TABLE IF EXISTS {self._user_state_table}" + + def _get_drop_metadata_table_sql(self) -> str: + """Get MySQL DROP TABLE SQL for ADK internal metadata.""" + return f"DROP TABLE IF EXISTS {self._metadata_table}" + def _get_drop_tables_sql(self) -> "list[str]": - return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE IF EXISTS {self._events_table}", + f"DROP TABLE IF EXISTS {self._session_table}", + ] def _create_tables(self) -> None: with self._config.provide_session() as driver: driver.execute_script(run_(self._get_create_sessions_table_sql)()) driver.execute_script(run_(self._get_create_events_table_sql)()) - - async def create_tables(self) -> None: - """Create tables if they don't exist.""" - await async_(self._create_tables)() + driver.execute_script(run_(self._get_create_app_states_table_sql)()) + driver.execute_script(run_(self._get_create_user_states_table_sql)()) + driver.execute_script(run_(self._get_create_metadata_table_sql)()) + driver.execute_script(run_(self._get_seed_metadata_sql)()) def _create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None @@ -146,12 +278,6 @@ def _create_session( raise RuntimeError(msg) return result - async def create_session( - self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None - ) -> SessionRecord: - """Create a new session.""" - return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) - def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = None) -> "SessionRecord | None": sql = f""" SELECT id, app_name, user_id, state, create_time, update_time @@ -191,12 +317,6 @@ def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = No return None raise - async def get_session( - self, session_id: str, *, renew_for: "int | timedelta | None" = None - ) -> "SessionRecord | None": - """Get session by ID.""" - return await async_(self._get_session)(session_id, renew_for) - def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: state_json = to_json(state) @@ -214,10 +334,6 @@ def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> Non cursor.close() conn.commit() - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state.""" - await async_(self._update_session_state)(session_id, state) - def _delete_session(self, session_id: str) -> None: sql = f"DELETE FROM {self._session_table} WHERE id = %s" @@ -229,10 +345,6 @@ def _delete_session(self, session_id: str) -> None: cursor.close() conn.commit() - async def delete_session(self, session_id: str) -> None: - """Delete session and associated events.""" - await async_(self._delete_session)(session_id) - def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": if user_id is None: sql = f""" @@ -276,10 +388,6 @@ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Ses return [] raise - async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app.""" - return await async_(self._list_sessions)(app_name, user_id) - def _append_event_and_update_state( self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" ) -> SessionRecord: @@ -350,12 +458,6 @@ def _append_event_and_update_state( update_time=row[5], ) - async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" - ) -> SessionRecord: - """Atomically append an event and update the session's durable state.""" - return await async_(self._append_event_and_update_state)(event_record, session_id, state) - def _insert_event(self, event_record: EventRecord) -> None: event_data = event_record["event_data"] event_data_str = to_json(event_data) if not isinstance(event_data, str) else event_data @@ -438,12 +540,6 @@ def _get_events( return [] raise - async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None - ) -> "list[EventRecord]": - """Get events for a session.""" - return await async_(self._get_events)(session_id, after_timestamp, limit) - def _delete_expired_events(self, before: "datetime") -> int: sql = f"DELETE FROM {self._events_table} WHERE timestamp < %s" @@ -461,10 +557,6 @@ def _delete_expired_events(self, before: "datetime") -> int: return 0 raise - async def delete_expired_events(self, before: "datetime") -> int: - """Delete events older than the given timestamp.""" - return await async_(self._delete_expired_events)(before) - def _delete_idle_sessions(self, updated_before: "datetime") -> int: sql = f"DELETE FROM {self._session_table} WHERE update_time < %s" @@ -482,18 +574,106 @@ def _delete_idle_sessions(self, updated_before: "datetime") -> int: return 0 raise - async def delete_idle_sessions(self, updated_before: "datetime") -> int: - """Delete sessions whose update_time predates the given threshold.""" - return await async_(self._delete_idle_sessions)(updated_before) + def _get_app_state(self, app_name: str) -> "dict[str, Any] | None": + sql = f"SELECT state FROM {self._app_state_table} WHERE app_name = %s" + + try: + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, (app_name,)) + row = cursor.fetchone() + finally: + cursor.close() + return from_json(row[0]) if row is not None and isinstance(row[0], str) else (row[0] if row else None) + except pymysql.MySQLError as exc: + if "doesn't exist" in str(exc) or getattr(exc, "args", [None])[0] == MYSQL_TABLE_NOT_FOUND_ERROR: + return None + raise + + def _get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + sql = f"SELECT state FROM {self._user_state_table} WHERE app_name = %s AND user_id = %s" + + try: + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, (app_name, user_id)) + row = cursor.fetchone() + finally: + cursor.close() + return from_json(row[0]) if row is not None and isinstance(row[0], str) else (row[0] if row else None) + except pymysql.MySQLError as exc: + if "doesn't exist" in str(exc) or getattr(exc, "args", [None])[0] == MYSQL_TABLE_NOT_FOUND_ERROR: + return None + raise + + def _upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES (%s, %s, UTC_TIMESTAMP(6)) + ON DUPLICATE KEY UPDATE state = VALUES(state), update_time = UTC_TIMESTAMP(6) + """ + + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, (app_name, to_json(state))) + finally: + cursor.close() + conn.commit() + + def _upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (%s, %s, %s, UTC_TIMESTAMP(6)) + ON DUPLICATE KEY UPDATE state = VALUES(state), update_time = UTC_TIMESTAMP(6) + """ + + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, (app_name, user_id, to_json(state))) + finally: + cursor.close() + conn.commit() + + def _get_metadata(self, key: str) -> "str | None": + sql = f"SELECT value FROM {self._metadata_table} WHERE `key` = %s" + + try: + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, (key,)) + row = cursor.fetchone() + finally: + cursor.close() + return row[0] if row is not None else None + except pymysql.MySQLError as exc: + if "doesn't exist" in str(exc) or getattr(exc, "args", [None])[0] == MYSQL_TABLE_NOT_FOUND_ERROR: + return None + raise + + def _set_metadata(self, key: str, value: str) -> None: + sql = f""" + INSERT INTO {self._metadata_table} (`key`, value) + VALUES (%s, %s) + ON DUPLICATE KEY UPDATE value = VALUES(value) + """ + + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, (key, value)) + finally: + cursor.close() + conn.commit() def _append_event(self, event_record: EventRecord) -> None: """Synchronous implementation of append_event.""" self._insert_event(event_record) - async def append_event(self, event_record: EventRecord) -> None: - """Append an event to a session.""" - await async_(self._append_event)(event_record) - class PyMysqlADKMemoryStore(BaseAsyncADKMemoryStore["PyMysqlConfig"]): """MySQL/MariaDB ADK memory store using PyMySQL.""" @@ -503,6 +683,28 @@ class PyMysqlADKMemoryStore(BaseAsyncADKMemoryStore["PyMysqlConfig"]): def __init__(self, config: "PyMysqlConfig") -> None: super().__init__(config) + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + + async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication.""" + return await async_(self._insert_memory_entries)(entries, owner_id) + + async def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query.""" + return await async_(self._search_entries)(query, app_name, user_id, limit) + + async def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return await async_(self._delete_entries_by_session)(session_id) + + async def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + return await async_(self._delete_entries_older_than)(days) + async def _get_create_memory_table_sql(self) -> str: owner_id_line = "" fk_constraint = "" @@ -544,10 +746,6 @@ def _create_tables(self) -> None: with self._config.provide_session() as driver: driver.execute_script(run_(self._get_create_memory_table_sql)()) - async def create_tables(self) -> None: - """Create tables if they don't exist.""" - await async_(self._create_tables)() - def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: if not self._enabled: msg = "Memory store is disabled" @@ -614,10 +812,6 @@ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "objec conn.commit() return inserted_count - async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: - """Bulk insert memory entries with deduplication.""" - return await async_(self._insert_memory_entries)(entries, owner_id) - def _search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": @@ -658,12 +852,6 @@ def _search_entries( return [cast("MemoryRecord", dict(zip(columns, row, strict=False))) for row in rows] - async def search_entries( - self, query: str, app_name: str, user_id: str, limit: "int | None" = None - ) -> "list[MemoryRecord]": - """Search memory entries by text query.""" - return await async_(self._search_entries)(query, app_name, user_id, limit) - def _delete_entries_by_session(self, session_id: str) -> int: if not self._enabled: msg = "Memory store is disabled" @@ -679,10 +867,6 @@ def _delete_entries_by_session(self, session_id: str) -> int: finally: cursor.close() - async def delete_entries_by_session(self, session_id: str) -> int: - """Delete all memory entries for a specific session.""" - return await async_(self._delete_entries_by_session)(session_id) - def _delete_entries_older_than(self, days: int) -> int: if not self._enabled: msg = "Memory store is disabled" @@ -700,7 +884,3 @@ def _delete_entries_older_than(self, days: int) -> int: return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 finally: cursor.close() - - async def delete_entries_older_than(self, days: int) -> int: - """Delete memory entries older than specified days.""" - return await async_(self._delete_entries_older_than)(days) diff --git a/sqlspec/adapters/spanner/adk/store.py b/sqlspec/adapters/spanner/adk/store.py index 50a3233e3..ea6dad191 100644 --- a/sqlspec/adapters/spanner/adk/store.py +++ b/sqlspec/adapters/spanner/adk/store.py @@ -41,6 +41,82 @@ def __init__(self, config: SpannerSyncConfig) -> None: ) self._events_row_deletion_policy = _spanner_row_deletion_policy(adk_config, "event_ttl_seconds", "timestamp") + async def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new session.""" + return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) + + async def get_session( + self, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + """Get session by ID.""" + return await async_(self._get_session)(session_id, renew_for) + + async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + """Update session state.""" + await async_(self._update_session_state)(session_id, state) + + async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app.""" + return await async_(self._list_sessions)(app_name, user_id) + + async def delete_session(self, session_id: str) -> None: + """Delete session and associated events.""" + await async_(self._delete_session)(session_id) + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> SessionRecord: + """Atomically append an event and update the session's durable state.""" + return await async_(self._append_event_and_update_state)(event_record, session_id, state) + + async def get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + """Get events for a session.""" + return await async_(self._get_events)(session_id, after_timestamp, limit) + + async def delete_expired_events(self, before: "datetime") -> int: + """Return 0 because Spanner row deletion policies own TTL cleanup.""" + return 0 + + async def delete_idle_sessions(self, updated_before: "datetime") -> int: + """Return 0 because Spanner row deletion policies own TTL cleanup.""" + return 0 + + async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Return app-scoped state for an application.""" + return await async_(self._get_app_state)(app_name) + + async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Return user-scoped state for an application user.""" + return await async_(self._get_user_state)(app_name, user_id) + + async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Insert or replace app-scoped state for an application.""" + await async_(self._upsert_app_state)(app_name, state) + + async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Insert or replace user-scoped state for an application user.""" + await async_(self._upsert_user_state)(app_name, user_id, state) + + async def get_metadata(self, key: str) -> "str | None": + """Return a value from the ADK internal metadata table.""" + return await async_(self._get_metadata)(key) + + async def set_metadata(self, key: str, value: str) -> None: + """Set a value in the ADK internal metadata table.""" + await async_(self._set_metadata)(key, value) + + async def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session.""" + await async_(self._append_event)(event_record) + + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + def _database(self) -> "Database": return self._config.get_database() @@ -117,12 +193,6 @@ def _create_session( "update_time": datetime.now(timezone.utc), } - async def create_session( - self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None - ) -> SessionRecord: - """Create a new session.""" - return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) - def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = None) -> "SessionRecord | None": if renew_for is not None and self._calculate_expires_at(renew_for) is not None: update_sql = f""" @@ -159,12 +229,6 @@ def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = No } return record - async def get_session( - self, session_id: str, *, renew_for: "int | timedelta | None" = None - ) -> "SessionRecord | None": - """Get session by ID.""" - return await async_(self._get_session)(session_id, renew_for) - def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: params = {"id": session_id, "state": to_json(state)} json_type = _json_param_type() @@ -177,10 +241,6 @@ def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> Non sql = f"{sql} AND shard_id = MOD(FARM_FINGERPRINT(@id), {self._shard_count})" self._run_write([(sql, params, {"id": SPANNER_PARAM_TYPES.STRING, "state": json_type})]) - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state.""" - await async_(self._update_session_state)(session_id, state) - def _list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[SessionRecord]": sql = f""" SELECT id, app_name, user_id, state, create_time, update_time{", " + self._owner_id_column_name if self._owner_id_column_name else ""} @@ -212,10 +272,6 @@ def _list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[S records.append(record) return records - async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app.""" - return await async_(self._list_sessions)(app_name, user_id) - def _delete_session(self, session_id: str) -> None: shard_clause = ( f" AND shard_id = MOD(FARM_FINGERPRINT(@session_id), {self._shard_count})" if self._shard_count > 1 else "" @@ -226,10 +282,6 @@ def _delete_session(self, session_id: str) -> None: types = {"session_id": SPANNER_PARAM_TYPES.STRING} self._run_write([(delete_events_sql, params, types), (delete_session_sql, params, types)]) - async def delete_session(self, session_id: str) -> None: - """Delete session and associated events.""" - await async_(self._delete_session)(session_id) - def _append_event_and_update_state( self, event_record: "EventRecord", session_id: str, state: "dict[str, Any]" ) -> SessionRecord: @@ -279,12 +331,6 @@ def _append_event_and_update_state( raise ValueError(msg) return record - async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" - ) -> SessionRecord: - """Atomically append an event and update the session's durable state.""" - return await async_(self._append_event_and_update_state)(event_record, session_id, state) - def _insert_event(self, event_record: "EventRecord") -> None: event_params: dict[str, Any] = { "session_id": event_record["session_id"], @@ -332,28 +378,83 @@ def _get_events( for row in rows ] - async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None - ) -> "list[EventRecord]": - """Get events for a session.""" - return await async_(self._get_events)(session_id, after_timestamp, limit) + def _get_app_state(self, app_name: str) -> "dict[str, Any] | None": + sql = f""" + SELECT state + FROM {self._app_state_table} + WHERE app_name = @app_name + LIMIT 1 + """ + rows = self._run_read(sql, {"app_name": app_name}, {"app_name": SPANNER_PARAM_TYPES.STRING}) + return self._decode_state(rows[0][0]) if rows else None - async def delete_expired_events(self, before: "datetime") -> int: - """Return 0 because Spanner row deletion policies own TTL cleanup.""" - return 0 + def _get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + sql = f""" + SELECT state + FROM {self._user_state_table} + WHERE app_name = @app_name AND user_id = @user_id + LIMIT 1 + """ + params = {"app_name": app_name, "user_id": user_id} + types = {"app_name": SPANNER_PARAM_TYPES.STRING, "user_id": SPANNER_PARAM_TYPES.STRING} + rows = self._run_read(sql, params, types) + return self._decode_state(rows[0][0]) if rows else None - async def delete_idle_sessions(self, updated_before: "datetime") -> int: - """Return 0 because Spanner row deletion policies own TTL cleanup.""" - return 0 + def _upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + delete_sql = f"DELETE FROM {self._app_state_table} WHERE app_name = @app_name" + insert_sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES (@app_name, @state, PENDING_COMMIT_TIMESTAMP()) + """ + params = {"app_name": app_name, "state": to_json(state)} + types = {"app_name": SPANNER_PARAM_TYPES.STRING, "state": _json_param_type()} + self._run_write([ + (delete_sql, {"app_name": app_name}, {"app_name": SPANNER_PARAM_TYPES.STRING}), + (insert_sql, params, types), + ]) + + def _upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + delete_sql = f"DELETE FROM {self._user_state_table} WHERE app_name = @app_name AND user_id = @user_id" + insert_sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (@app_name, @user_id, @state, PENDING_COMMIT_TIMESTAMP()) + """ + params = {"app_name": app_name, "user_id": user_id, "state": to_json(state)} + types = { + "app_name": SPANNER_PARAM_TYPES.STRING, + "user_id": SPANNER_PARAM_TYPES.STRING, + "state": _json_param_type(), + } + self._run_write([ + ( + delete_sql, + {"app_name": app_name, "user_id": user_id}, + {"app_name": SPANNER_PARAM_TYPES.STRING, "user_id": SPANNER_PARAM_TYPES.STRING}, + ), + (insert_sql, params, types), + ]) + + def _get_metadata(self, key: str) -> "str | None": + sql = f""" + SELECT value + FROM {self._metadata_table} + WHERE key = @key + LIMIT 1 + """ + rows = self._run_read(sql, {"key": key}, {"key": SPANNER_PARAM_TYPES.STRING}) + return rows[0][0] if rows else None + + def _set_metadata(self, key: str, value: str) -> None: + delete_sql = f"DELETE FROM {self._metadata_table} WHERE key = @key" + insert_sql = f"INSERT INTO {self._metadata_table} (key, value) VALUES (@key, @value)" + params = {"key": key, "value": value} + types = {"key": SPANNER_PARAM_TYPES.STRING, "value": SPANNER_PARAM_TYPES.STRING} + self._run_write([(delete_sql, {"key": key}, {"key": SPANNER_PARAM_TYPES.STRING}), (insert_sql, params, types)]) def _append_event(self, event_record: EventRecord) -> None: """Synchronous implementation of append_event.""" self._insert_event(event_record) - async def append_event(self, event_record: EventRecord) -> None: - """Append an event to a session.""" - await async_(self._append_event)(event_record) - def _create_tables(self) -> None: database = self._database() existing_tables = {t.table_id for t in database.list_tables()} # type: ignore[no-untyped-call] @@ -363,13 +464,16 @@ def _create_tables(self) -> None: ddl_statements.append(run_(self._get_create_sessions_table_sql)()) if self._events_table not in existing_tables: ddl_statements.append(run_(self._get_create_events_table_sql)()) + if self._app_state_table not in existing_tables: + ddl_statements.append(run_(self._get_create_app_states_table_sql)()) + if self._user_state_table not in existing_tables: + ddl_statements.append(run_(self._get_create_user_states_table_sql)()) + if self._metadata_table not in existing_tables: + ddl_statements.append(run_(self._get_create_metadata_table_sql)()) if ddl_statements: database.update_ddl(ddl_statements).result(300) # type: ignore[no-untyped-call] - - async def create_tables(self) -> None: - """Create tables if they don't exist.""" - await async_(self._create_tables)() + self._set_metadata("schema_version", "1") async def _get_create_sessions_table_sql(self) -> str: owner_line = "" @@ -413,8 +517,53 @@ async def _get_create_events_table_sql(self) -> str: ) {pk}{options}{self._events_row_deletion_policy} """ + async def _get_create_app_states_table_sql(self) -> str: + return f""" +CREATE TABLE {self._app_state_table} ( + app_name STRING(128) NOT NULL, + state JSON NOT NULL, + update_time TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp=true) +) PRIMARY KEY (app_name) +""" + + async def _get_create_user_states_table_sql(self) -> str: + return f""" +CREATE TABLE {self._user_state_table} ( + app_name STRING(128) NOT NULL, + user_id STRING(128) NOT NULL, + state JSON NOT NULL, + update_time TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp=true) +) PRIMARY KEY (app_name, user_id) +""" + + async def _get_create_metadata_table_sql(self) -> str: + return f""" +CREATE TABLE {self._metadata_table} ( + key STRING(128) NOT NULL, + value STRING(512) NOT NULL +) PRIMARY KEY (key) +""" + + async def _get_seed_metadata_sql(self) -> str: + return f"INSERT INTO {self._metadata_table} (key, value) VALUES ('schema_version', '1')" + + def _get_drop_app_states_table_sql(self) -> str: + return f"DROP TABLE {self._app_state_table}" + + def _get_drop_user_states_table_sql(self) -> str: + return f"DROP TABLE {self._user_state_table}" + + def _get_drop_metadata_table_sql(self) -> str: + return f"DROP TABLE {self._metadata_table}" + def _get_drop_tables_sql(self) -> "list[str]": - return [f"DROP TABLE {self._events_table}", f"DROP TABLE {self._session_table}"] + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE {self._events_table}", + f"DROP TABLE {self._session_table}", + ] class SpannerSyncADKMemoryStore(BaseAsyncADKMemoryStore[SpannerSyncConfig]): @@ -432,6 +581,28 @@ def __init__(self, config: SpannerSyncConfig) -> None: cast("dict[str, Any]", adk_config), "memory_ttl_seconds", "inserted_at" ) + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + + async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication.""" + return await async_(self._insert_memory_entries)(entries, owner_id) + + async def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query.""" + return await async_(self._search_entries)(query, app_name, user_id, limit) + + async def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return await async_(self._delete_entries_by_session)(session_id) + + async def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + return await async_(self._delete_entries_older_than)(days) + def _database(self) -> "Database": return self._config.get_database() @@ -488,10 +659,6 @@ def _create_tables(self) -> None: if ddl_statements: database.update_ddl(ddl_statements).result(300) # type: ignore[no-untyped-call] - async def create_tables(self) -> None: - """Create tables if they don't exist.""" - await async_(self._create_tables)() - async def _get_create_memory_table_sql(self) -> "list[str]": owner_line = "" if self._owner_id_column_ddl: @@ -604,10 +771,6 @@ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "objec self._run_write(statements) return inserted_count - async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: - """Bulk insert memory entries with deduplication.""" - return await async_(self._insert_memory_entries)(entries, owner_id) - def _event_exists(self, event_id: str) -> bool: sql = f"SELECT event_id FROM {self._memory_table} WHERE event_id = @event_id LIMIT 1" rows = self._run_read(sql, {"event_id": event_id}, {"event_id": SPANNER_PARAM_TYPES.STRING}) @@ -626,12 +789,6 @@ def _search_entries( return self._search_entries_fts(query, app_name, user_id, effective_limit) return self._search_entries_simple(query, app_name, user_id, effective_limit) - async def search_entries( - self, query: str, app_name: str, user_id: str, limit: "int | None" = None - ) -> "list[MemoryRecord]": - """Search memory entries by text query.""" - return await async_(self._search_entries)(query, app_name, user_id, limit) - def _search_entries_fts(self, query: str, app_name: str, user_id: str, limit: int) -> "list[MemoryRecord]": sql = f""" SELECT id, session_id, app_name, user_id, event_id, author, @@ -681,10 +838,6 @@ def _delete_entries_by_session(self, session_id: str) -> int: types = {"session_id": SPANNER_PARAM_TYPES.STRING} return self._execute_update(sql, params, types) - async def delete_entries_by_session(self, session_id: str) -> int: - """Delete all memory entries for a specific session.""" - return await async_(self._delete_entries_by_session)(session_id) - def _delete_entries_older_than(self, days: int) -> int: cutoff = datetime.now(timezone.utc) - timedelta(days=days) sql = f"DELETE FROM {self._memory_table} WHERE inserted_at < @cutoff" @@ -692,10 +845,6 @@ def _delete_entries_older_than(self, days: int) -> int: types = {"cutoff": SPANNER_PARAM_TYPES.TIMESTAMP} return self._execute_update(sql, params, types) - async def delete_entries_older_than(self, days: int) -> int: - """Delete memory entries older than specified days.""" - return await async_(self._delete_entries_older_than)(days) - def _rows_to_records(self, rows: "list[Any]") -> "list[MemoryRecord]": return [ { diff --git a/sqlspec/adapters/sqlite/adk/store.py b/sqlspec/adapters/sqlite/adk/store.py index 6104e7f66..3db5e2b72 100644 --- a/sqlspec/adapters/sqlite/adk/store.py +++ b/sqlspec/adapters/sqlite/adk/store.py @@ -113,12 +113,176 @@ def __init__(self, config: "SqliteConfig") -> None: Notes: Configuration is read from config.extension_config["adk"]: - - session_table: Sessions table name (default: "adk_sessions") - - events_table: Events table name (default: "adk_events") + - session_table: Sessions table name (default: "adk_session") + - events_table: Events table name (default: "adk_event") - owner_id_column: Optional owner FK column DDL (default: None) """ super().__init__(config) + async def create_tables(self) -> None: + """Create both sessions and events tables if they don't exist.""" + await async_(self._create_tables)() + + async def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new session. + + Args: + session_id: Unique session identifier. + app_name: Application name. + user_id: User identifier. + state: Initial session state. + owner_id: Optional owner ID value for owner ID column. + + Returns: + Created session record. + + Notes: + Uses Julian Day for create_time and update_time. + State is always JSON-serialized (empty dict becomes '{}', never NULL). + If owner_id_column is configured, owner_id is inserted into that column. + """ + return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) + + async def get_session( + self, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + """Get session by ID. + + Args: + session_id: Session identifier. + renew_for: If positive, touch update_time while reading. + + Returns: + Session record or None if not found. + + Notes: + SQLite returns Julian Day (REAL) for timestamps. + JSON is parsed from TEXT storage. + """ + return await async_(self._get_session)(session_id, renew_for) + + async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + """Update session state. + + Args: + session_id: Session identifier. + state: New state dictionary (replaces existing state). + + Notes: + This replaces the entire state dictionary. + Updates update_time to current Julian Day. + Empty dict is serialized as '{}', never NULL. + """ + await async_(self._update_session_state)(session_id, state) + + async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app, optionally filtered by user. + + Args: + app_name: Application name. + user_id: User identifier. If None, lists all sessions for the app. + + Returns: + List of session records ordered by update_time DESC. + + Notes: + Uses composite index on (app_name, user_id) when user_id is provided. + """ + return await async_(self._list_sessions)(app_name, user_id) + + async def delete_session(self, session_id: str) -> None: + """Delete session and all associated events (cascade). + + Args: + session_id: Session identifier. + + Notes: + Foreign key constraint ensures events are cascade-deleted. + """ + await async_(self._delete_session)(session_id) + + async def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session. + + Args: + event_record: Event record with 5 keys: session_id, invocation_id, + author, timestamp, event_data. + + Notes: + Uses Julian Day for timestamp. + event_data dict is serialized to TEXT as event_data column. + """ + await async_(self._append_event)(event_record) + + async def append_event_and_update_state( + self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + ) -> SessionRecord: + """Atomically append an event and update the session's durable state. + + Inserts the event and updates the session state + update_time in a + single transaction, returning the updated SessionRecord via RETURNING. + + Args: + event_record: Event record to store. + session_id: Session identifier whose state should be updated. + state: Post-append durable state snapshot (temp: keys already + stripped by the service layer). + """ + return await async_(self._append_event_and_update_state)(event_record, session_id, state) + + async def get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + """Get events for a session. + + Args: + session_id: Session identifier. + after_timestamp: Only return events after this time. + limit: Maximum number of events to return. + + Returns: + List of event records ordered by timestamp ASC. + + Notes: + Uses index on (session_id, timestamp ASC). + Parses event_data TEXT back to dict for event_data field. + """ + return await async_(self._get_events)(session_id, after_timestamp, limit) + + async def delete_expired_events(self, before: datetime) -> int: + """Delete events older than the given timestamp.""" + return await async_(self._delete_expired_events)(before) + + async def delete_idle_sessions(self, updated_before: datetime) -> int: + """Delete sessions whose update_time predates the given threshold.""" + return await async_(self._delete_idle_sessions)(updated_before) + + async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Return app-scoped state for an application.""" + return await async_(self._get_app_state)(app_name) + + async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Return user-scoped state for an application user.""" + return await async_(self._get_user_state)(app_name, user_id) + + async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Insert or replace app-scoped state for an application.""" + await async_(self._upsert_app_state)(app_name, state) + + async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Insert or replace user-scoped state for an application user.""" + await async_(self._upsert_user_state)(app_name, user_id, state) + + async def get_metadata(self, key: str) -> "str | None": + """Return a value from the ADK internal metadata table.""" + return await async_(self._get_metadata)(key) + + async def set_metadata(self, key: str, value: str) -> None: + """Set a value in the ADK internal metadata table.""" + await async_(self._set_metadata)(key, value) + def _apply_pragmas(self, connection: Any) -> None: """Apply PRAGMA optimization profile for this connection. @@ -139,7 +303,7 @@ async def _get_create_sessions_table_sql(self) -> str: """Get SQLite CREATE TABLE SQL for sessions. Returns: - SQL statement to create adk_sessions table with indexes. + SQL statement to create adk_session table with indexes. Notes: - TEXT for IDs, names, and JSON state @@ -171,7 +335,7 @@ async def _get_create_events_table_sql(self) -> str: """Get SQLite CREATE TABLE SQL for events. Returns: - SQL statement to create adk_events table with indexes. + SQL statement to create adk_event table with indexes. Notes: - TEXT for IDs and indexed scalars @@ -194,6 +358,56 @@ async def _get_create_events_table_sql(self) -> str: ON {self._events_table}(session_id, timestamp ASC); """ + async def _get_create_app_states_table_sql(self) -> str: + """Get SQLite CREATE TABLE SQL for app-scoped state.""" + return f""" + CREATE TABLE IF NOT EXISTS {self._app_state_table} ( + app_name TEXT PRIMARY KEY, + state TEXT NOT NULL DEFAULT '{{}}', + update_time REAL NOT NULL + ); + """ + + async def _get_create_user_states_table_sql(self) -> str: + """Get SQLite CREATE TABLE SQL for user-scoped state.""" + return f""" + CREATE TABLE IF NOT EXISTS {self._user_state_table} ( + app_name TEXT NOT NULL, + user_id TEXT NOT NULL, + state TEXT NOT NULL DEFAULT '{{}}', + update_time REAL NOT NULL, + PRIMARY KEY (app_name, user_id) + ); + """ + + async def _get_create_metadata_table_sql(self) -> str: + """Get SQLite CREATE TABLE SQL for ADK internal metadata.""" + return f""" + CREATE TABLE IF NOT EXISTS {self._metadata_table} ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL + ); + """ + + async def _get_seed_metadata_sql(self) -> str: + """Get SQLite SQL that seeds the ADK schema version metadata row.""" + return f""" + INSERT OR IGNORE INTO {self._metadata_table} (key, value) + VALUES ('schema_version', '1') + """ + + def _get_drop_app_states_table_sql(self) -> str: + """Get SQLite DROP TABLE SQL for app-scoped state.""" + return f"DROP TABLE IF EXISTS {self._app_state_table}" + + def _get_drop_user_states_table_sql(self) -> str: + """Get SQLite DROP TABLE SQL for user-scoped state.""" + return f"DROP TABLE IF EXISTS {self._user_state_table}" + + def _get_drop_metadata_table_sql(self) -> str: + """Get SQLite DROP TABLE SQL for ADK internal metadata.""" + return f"DROP TABLE IF EXISTS {self._metadata_table}" + def _get_drop_tables_sql(self) -> "list[str]": """Get SQLite DROP TABLE SQL statements. @@ -204,7 +418,13 @@ def _get_drop_tables_sql(self) -> "list[str]": Order matters: drop events table (child) before sessions (parent). SQLite automatically drops indexes when dropping tables. """ - return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE IF EXISTS {self._events_table}", + f"DROP TABLE IF EXISTS {self._session_table}", + ] def _create_tables(self) -> None: """Synchronous implementation of create_tables.""" @@ -212,10 +432,10 @@ def _create_tables(self) -> None: self._apply_pragmas(driver.connection) driver.execute_script(run_(self._get_create_sessions_table_sql)()) driver.execute_script(run_(self._get_create_events_table_sql)()) - - async def create_tables(self) -> None: - """Create both sessions and events tables if they don't exist.""" - await async_(self._create_tables)() + driver.execute_script(run_(self._get_create_app_states_table_sql)()) + driver.execute_script(run_(self._get_create_user_states_table_sql)()) + driver.execute_script(run_(self._get_create_metadata_table_sql)()) + driver.execute_script(run_(self._get_seed_metadata_sql)()) def _create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None @@ -249,28 +469,6 @@ def _create_session( id=session_id, app_name=app_name, user_id=user_id, state=state, create_time=now, update_time=now ) - async def create_session( - self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None - ) -> SessionRecord: - """Create a new session. - - Args: - session_id: Unique session identifier. - app_name: Application name. - user_id: User identifier. - state: Initial session state. - owner_id: Optional owner ID value for owner ID column. - - Returns: - Created session record. - - Notes: - Uses Julian Day for create_time and update_time. - State is always JSON-serialized (empty dict becomes '{}', never NULL). - If owner_id_column is configured, owner_id is inserted into that column. - """ - return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) - def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = None) -> "SessionRecord | None": """Synchronous implementation of get_session.""" sql = f""" @@ -306,24 +504,6 @@ def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = No return None raise - async def get_session( - self, session_id: str, *, renew_for: "int | timedelta | None" = None - ) -> "SessionRecord | None": - """Get session by ID. - - Args: - session_id: Session identifier. - renew_for: If positive, touch update_time while reading. - - Returns: - Session record or None if not found. - - Notes: - SQLite returns Julian Day (REAL) for timestamps. - JSON is parsed from TEXT storage. - """ - return await async_(self._get_session)(session_id, renew_for) - def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: """Synchronous implementation of update_session_state.""" now_julian = _datetime_to_julian(datetime.now(timezone.utc)) @@ -340,20 +520,6 @@ def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> Non conn.execute(sql, (state_json, now_julian, session_id)) conn.commit() - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state. - - Args: - session_id: Session identifier. - state: New state dictionary (replaces existing state). - - Notes: - This replaces the entire state dictionary. - Updates update_time to current Julian Day. - Empty dict is serialized as '{}', never NULL. - """ - await async_(self._update_session_state)(session_id, state) - def _list_sessions(self, app_name: str, user_id: "str | None") -> "list[SessionRecord]": """Synchronous implementation of list_sessions.""" if user_id is None: @@ -395,21 +561,6 @@ def _list_sessions(self, app_name: str, user_id: "str | None") -> "list[SessionR return [] raise - async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app, optionally filtered by user. - - Args: - app_name: Application name. - user_id: User identifier. If None, lists all sessions for the app. - - Returns: - List of session records ordered by update_time DESC. - - Notes: - Uses composite index on (app_name, user_id) when user_id is provided. - """ - return await async_(self._list_sessions)(app_name, user_id) - def _delete_session(self, session_id: str) -> None: """Synchronous implementation of delete_session.""" sql = f"DELETE FROM {self._session_table} WHERE id = ?" @@ -419,17 +570,6 @@ def _delete_session(self, session_id: str) -> None: conn.execute(sql, (session_id,)) conn.commit() - async def delete_session(self, session_id: str) -> None: - """Delete session and all associated events (cascade). - - Args: - session_id: Session identifier. - - Notes: - Foreign key constraint ensures events are cascade-deleted. - """ - await async_(self._delete_session)(session_id) - def _append_event(self, event_record: EventRecord) -> None: """Synchronous implementation of append_event.""" timestamp_julian = _datetime_to_julian(event_record["timestamp"]) @@ -460,19 +600,6 @@ def _append_event(self, event_record: EventRecord) -> None: ) conn.commit() - async def append_event(self, event_record: EventRecord) -> None: - """Append an event to a session. - - Args: - event_record: Event record with 5 keys: session_id, invocation_id, - author, timestamp, event_data. - - Notes: - Uses Julian Day for timestamp. - event_data dict is serialized to TEXT as event_data column. - """ - await async_(self._append_event)(event_record) - def _append_event_and_update_state( self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" ) -> SessionRecord: @@ -528,22 +655,6 @@ def _append_event_and_update_state( update_time=_julian_to_datetime(row[5]), ) - async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" - ) -> SessionRecord: - """Atomically append an event and update the session's durable state. - - Inserts the event and updates the session state + update_time in a - single transaction, returning the updated SessionRecord via RETURNING. - - Args: - event_record: Event record to store. - session_id: Session identifier whose state should be updated. - state: Post-append durable state snapshot (temp: keys already - stripped by the service layer). - """ - return await async_(self._append_event_and_update_state)(event_record, session_id, state) - def _get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None ) -> "list[EventRecord]": @@ -586,25 +697,6 @@ def _get_events( return [] raise - async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None - ) -> "list[EventRecord]": - """Get events for a session. - - Args: - session_id: Session identifier. - after_timestamp: Only return events after this time. - limit: Maximum number of events to return. - - Returns: - List of event records ordered by timestamp ASC. - - Notes: - Uses index on (session_id, timestamp ASC). - Parses event_data TEXT back to dict for event_data field. - """ - return await async_(self._get_events)(session_id, after_timestamp, limit) - def _delete_expired_events(self, before: datetime) -> int: """Synchronous implementation of delete_expired_events.""" sql = f"DELETE FROM {self._events_table} WHERE timestamp < ?" @@ -615,10 +707,6 @@ def _delete_expired_events(self, before: datetime) -> int: conn.commit() return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 - async def delete_expired_events(self, before: datetime) -> int: - """Delete events older than the given timestamp.""" - return await async_(self._delete_expired_events)(before) - def _delete_idle_sessions(self, updated_before: datetime) -> int: """Synchronous implementation of delete_idle_sessions.""" sql = f"DELETE FROM {self._session_table} WHERE update_time < ?" @@ -629,9 +717,90 @@ def _delete_idle_sessions(self, updated_before: datetime) -> int: conn.commit() return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 - async def delete_idle_sessions(self, updated_before: datetime) -> int: - """Delete sessions whose update_time predates the given threshold.""" - return await async_(self._delete_idle_sessions)(updated_before) + def _get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Synchronous implementation of get_app_state.""" + sql = f"SELECT state FROM {self._app_state_table} WHERE app_name = ?" + + try: + with self._config.provide_connection() as conn: + self._apply_pragmas(conn) + row = conn.execute(sql, (app_name,)).fetchone() + return from_json(row[0]) if row is not None and row[0] else None + except sqlite3.OperationalError as exc: + if SQLITE_TABLE_NOT_FOUND_ERROR in str(exc): + return None + raise + + def _get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Synchronous implementation of get_user_state.""" + sql = f"SELECT state FROM {self._user_state_table} WHERE app_name = ? AND user_id = ?" + + try: + with self._config.provide_connection() as conn: + self._apply_pragmas(conn) + row = conn.execute(sql, (app_name, user_id)).fetchone() + return from_json(row[0]) if row is not None and row[0] else None + except sqlite3.OperationalError as exc: + if SQLITE_TABLE_NOT_FOUND_ERROR in str(exc): + return None + raise + + def _upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Synchronous implementation of upsert_app_state.""" + sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES (?, ?, ?) + ON CONFLICT(app_name) DO UPDATE SET + state = excluded.state, + update_time = excluded.update_time + """ + + with self._config.provide_connection() as conn: + self._apply_pragmas(conn) + conn.execute(sql, (app_name, to_json(state), _datetime_to_julian(datetime.now(timezone.utc)))) + conn.commit() + + def _upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Synchronous implementation of upsert_user_state.""" + sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (?, ?, ?, ?) + ON CONFLICT(app_name, user_id) DO UPDATE SET + state = excluded.state, + update_time = excluded.update_time + """ + + with self._config.provide_connection() as conn: + self._apply_pragmas(conn) + conn.execute(sql, (app_name, user_id, to_json(state), _datetime_to_julian(datetime.now(timezone.utc)))) + conn.commit() + + def _get_metadata(self, key: str) -> "str | None": + """Synchronous implementation of get_metadata.""" + sql = f"SELECT value FROM {self._metadata_table} WHERE key = ?" + + try: + with self._config.provide_connection() as conn: + self._apply_pragmas(conn) + row = conn.execute(sql, (key,)).fetchone() + return row[0] if row is not None else None + except sqlite3.OperationalError as exc: + if SQLITE_TABLE_NOT_FOUND_ERROR in str(exc): + return None + raise + + def _set_metadata(self, key: str, value: str) -> None: + """Synchronous implementation of set_metadata.""" + sql = f""" + INSERT INTO {self._metadata_table} (key, value) + VALUES (?, ?) + ON CONFLICT(key) DO UPDATE SET value = excluded.value + """ + + with self._config.provide_connection() as conn: + self._apply_pragmas(conn) + conn.execute(sql, (key, value)) + conn.commit() class SqliteADKMemoryStore(BaseAsyncADKMemoryStore["SqliteConfig"]): @@ -693,6 +862,28 @@ def __init__(self, config: "SqliteConfig") -> None: """ super().__init__(config) + async def create_tables(self) -> None: + """Create tables if they don't exist.""" + await async_(self._create_tables)() + + async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication.""" + return await async_(self._insert_memory_entries)(entries, owner_id) + + async def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query.""" + return await async_(self._search_entries)(query, app_name, user_id, limit) + + async def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return await async_(self._delete_entries_by_session)(session_id) + + async def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + return await async_(self._delete_entries_older_than)(days) + async def _get_create_memory_table_sql(self) -> str: """Get SQLite CREATE TABLE SQL for memory entries. @@ -797,10 +988,6 @@ def _create_tables(self) -> None: self._enable_foreign_keys(driver.connection) driver.execute_script(run_(self._get_create_memory_table_sql)()) - async def create_tables(self) -> None: - """Create tables if they don't exist.""" - await async_(self._create_tables)() - def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: """Bulk insert memory entries with deduplication. @@ -885,10 +1072,6 @@ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "objec return inserted_count - async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: - """Bulk insert memory entries with deduplication.""" - return await async_(self._insert_memory_entries)(entries, owner_id) - def _search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": @@ -919,12 +1102,6 @@ def _search_entries( logger.warning("FTS search failed; falling back to simple search: %s", exc) return self._search_entries_simple(query, app_name, user_id, effective_limit) - async def search_entries( - self, query: str, app_name: str, user_id: str, limit: "int | None" = None - ) -> "list[MemoryRecord]": - """Search memory entries by text query.""" - return await async_(self._search_entries)(query, app_name, user_id, limit) - def _search_entries_fts(self, query: str, app_name: str, user_id: str, limit: int) -> "list[MemoryRecord]": sql = f""" SELECT m.id, m.session_id, m.app_name, m.user_id, m.event_id, m.author, @@ -996,10 +1173,6 @@ def _delete_entries_by_session(self, session_id: str) -> int: return deleted_count - async def delete_entries_by_session(self, session_id: str) -> int: - """Delete all memory entries for a specific session.""" - return await async_(self._delete_entries_by_session)(session_id) - def _delete_entries_older_than(self, days: int) -> int: """Delete memory entries older than specified days. @@ -1022,7 +1195,3 @@ def _delete_entries_older_than(self, days: int) -> int: conn.commit() return deleted_count - - async def delete_entries_older_than(self, days: int) -> int: - """Delete memory entries older than specified days.""" - return await async_(self._delete_entries_older_than)(days) diff --git a/sqlspec/config.py b/sqlspec/config.py index 22bebd7f5..5a0eb0025 100644 --- a/sqlspec/config.py +++ b/sqlspec/config.py @@ -920,7 +920,7 @@ class ADKConfig(TypedDict): """ session_table: NotRequired[str] - """Name of the sessions table. Default: 'adk_sessions' + """Name of the sessions table. Default: 'adk_session' Examples: "agent_sessions" @@ -929,7 +929,7 @@ class ADKConfig(TypedDict): """ events_table: NotRequired[str] - """Name of the events table. Default: 'adk_events' + """Name of the events table. Default: 'adk_event' Examples: "agent_events" @@ -937,6 +937,15 @@ class ADKConfig(TypedDict): "tenant_acme_events" """ + app_state_table: NotRequired[str] + """Name of the app-scoped state table. Default: 'adk_app_state'.""" + + user_state_table: NotRequired[str] + """Name of the user-scoped state table. Default: 'adk_user_state'.""" + + metadata_table: NotRequired[str] + """Name of the internal ADK metadata table. Default: 'adk_internal_metadata'.""" + memory_table: NotRequired[str] """Name of the memory entries table. Default: 'adk_memory_entries' diff --git a/sqlspec/extensions/adk/_config_utils.py b/sqlspec/extensions/adk/_config_utils.py index d8839b25c..d723479dc 100644 --- a/sqlspec/extensions/adk/_config_utils.py +++ b/sqlspec/extensions/adk/_config_utils.py @@ -99,10 +99,10 @@ def _get_adk_session_store_config(config: _ADKConfigSource) -> _ADKSessionStoreC user_state_table = _get_first_value(schema_config.get("user_state_table"), adk_config.get("user_state_table")) metadata_table = _get_first_value(schema_config.get("metadata_table"), adk_config.get("metadata_table")) result: _ADKSessionStoreConfig = { - "session_table": str(session_table) if session_table is not None else "adk_sessions", - "events_table": str(events_table) if events_table is not None else "adk_events", - "app_state_table": str(app_state_table) if app_state_table is not None else "adk_app_states", - "user_state_table": str(user_state_table) if user_state_table is not None else "adk_user_states", + "session_table": str(session_table) if session_table is not None else "adk_session", + "events_table": str(events_table) if events_table is not None else "adk_event", + "app_state_table": str(app_state_table) if app_state_table is not None else "adk_app_state", + "user_state_table": str(user_state_table) if user_state_table is not None else "adk_user_state", "metadata_table": str(metadata_table) if metadata_table is not None else "adk_internal_metadata", } owner_id = _get_first_value(schema_config.get("owner_id_column"), adk_config.get("owner_id_column")) diff --git a/sqlspec/extensions/adk/_versioning.py b/sqlspec/extensions/adk/_versioning.py index 2ceb5ac06..4375bafdf 100644 --- a/sqlspec/extensions/adk/_versioning.py +++ b/sqlspec/extensions/adk/_versioning.py @@ -17,7 +17,7 @@ ADK_MEMORY_PAYLOAD_VERSION: Final = 1 ADK_ARTIFACT_PAYLOAD_VERSION: Final = 1 -ADK_SCHEMA_VERSION_KEY: Final = "sqlspec.adk.schema_version" +ADK_SCHEMA_VERSION_KEY: Final = "schema_version" ADK_PAYLOAD_VERSION_KEYS: Final[dict[ADKPayloadKind, str]] = { "event": "sqlspec.adk.payload.event", "state": "sqlspec.adk.payload.state", diff --git a/sqlspec/extensions/adk/migrations/0001_create_adk_tables.py b/sqlspec/extensions/adk/migrations/0001_create_adk_tables.py index a354792cc..b288d8ab3 100644 --- a/sqlspec/extensions/adk/migrations/0001_create_adk_tables.py +++ b/sqlspec/extensions/adk/migrations/0001_create_adk_tables.py @@ -122,6 +122,10 @@ async def up(context: "MigrationContext | None" = None) -> "list[str]": statements = [ await store_instance._get_create_sessions_table_sql(), # pyright: ignore[reportPrivateUsage] await store_instance._get_create_events_table_sql(), # pyright: ignore[reportPrivateUsage] + await store_instance._get_create_app_states_table_sql(), # pyright: ignore[reportPrivateUsage] + await store_instance._get_create_user_states_table_sql(), # pyright: ignore[reportPrivateUsage] + await store_instance._get_create_metadata_table_sql(), # pyright: ignore[reportPrivateUsage] + await store_instance._get_seed_metadata_sql(), # pyright: ignore[reportPrivateUsage] ] if _is_memory_enabled(context): diff --git a/sqlspec/extensions/adk/service.py b/sqlspec/extensions/adk/service.py index 04205c484..d0ef811c4 100644 --- a/sqlspec/extensions/adk/service.py +++ b/sqlspec/extensions/adk/service.py @@ -11,7 +11,9 @@ compute_update_marker, event_to_record, filter_temp_state, + merge_scoped_state, record_to_session, + split_scoped_state, ) from sqlspec.utils.logging import get_logger, log_with_context @@ -86,10 +88,16 @@ async def create_session( state = {} persisted_state = filter_temp_state(state) + app_state, user_state, session_state = split_scoped_state(persisted_state) record = await self._store.create_session( - session_id=session_id, app_name=app_name, user_id=user_id, state=persisted_state + session_id=session_id, app_name=app_name, user_id=user_id, state=session_state ) + if app_state: + await self._store.upsert_app_state(app_name, app_state) + if user_state: + await self._store.upsert_user_state(app_name, user_id, user_state) + record["state"] = merge_scoped_state(record["state"], app_state, user_state) log_with_context( logger, logging.DEBUG, "adk.session.create", app_name=app_name, session_id=session_id, has_state=bool(state) ) @@ -131,6 +139,10 @@ async def get_session( ) return None + app_state = await self._store.get_app_state(app_name) + user_state = await self._store.get_user_state(app_name, user_id) + record["state"] = merge_scoped_state(record["state"], app_state, user_state) + after_timestamp = None limit = None @@ -244,6 +256,7 @@ async def append_event(self, session: "Session", event: "Event") -> "Event": durable_state = filter_temp_state(session.state) if event.actions and event.actions.state_delta: durable_state.update(event.actions.state_delta) + app_state, user_state, session_state = split_scoped_state(durable_state) # --- Stale-session detection --- current_record = await self._store.get_session(session.id) @@ -268,8 +281,13 @@ async def append_event(self, session: "Session", event: "Event") -> "Event": # --- Persist event and state atomically --- updated_record = await self._store.append_event_and_update_state( - event_record=event_record, session_id=session.id, state=durable_state + event_record=event_record, session_id=session.id, state=session_state ) + if app_state: + await self._store.upsert_app_state(session.app_name, app_state) + if user_state: + await self._store.upsert_user_state(session.app_name, session.user_id, user_state) + updated_record["state"] = merge_scoped_state(updated_record["state"], app_state, user_state) # Use the returned record directly — saves a round-trip vs a follow-up get_session(). session.last_update_time = updated_record["update_time"].timestamp() diff --git a/sqlspec/extensions/adk/store.py b/sqlspec/extensions/adk/store.py index 37921f659..e0509a532 100644 --- a/sqlspec/extensions/adk/store.py +++ b/sqlspec/extensions/adk/store.py @@ -75,8 +75,11 @@ class BaseAsyncADKStore(ABC, Generic[ConfigT]): Notes: Configuration is read from config.extension_config["adk"]: - - session_table: Sessions table name (default: "adk_sessions") - - events_table: Events table name (default: "adk_events") + - session_table: Sessions table name (default: "adk_session") + - events_table: Events table name (default: "adk_event") + - app_state_table: App-scoped state table name (default: "adk_app_state") + - user_state_table: User-scoped state table name (default: "adk_user_state") + - metadata_table: Internal metadata table name (default: "adk_internal_metadata") - owner_id_column: Optional owner FK column DDL (default: None) """ @@ -99,8 +102,11 @@ def __init__(self, config: ConfigT) -> None: Notes: Reads configuration from config.extension_config["adk"]: - - session_table: Sessions table name (default: "adk_sessions") - - events_table: Events table name (default: "adk_events") + - session_table: Sessions table name (default: "adk_session") + - events_table: Events table name (default: "adk_event") + - app_state_table: App-scoped state table name (default: "adk_app_state") + - user_state_table: User-scoped state table name (default: "adk_user_state") + - metadata_table: Internal metadata table name (default: "adk_internal_metadata") - owner_id_column: Optional owner FK column DDL (default: None) """ self._config = config @@ -120,14 +126,6 @@ def __init__(self, config: ConfigT) -> None: validate_identifier(self._user_state_table, label="table name") validate_identifier(self._metadata_table, label="table name") - def _get_store_config_from_extension(self) -> "dict[str, Any]": - """Extract ADK store configuration from config.extension_config. - - Returns: - Dict with session_table, events_table, and optionally owner_id_column. - """ - return dict(_get_adk_session_store_config(self._config)) - @property def config(self) -> ConfigT: """Return the database configuration.""" @@ -168,6 +166,14 @@ def owner_id_column_name(self) -> "str | None": """Return the owner ID column name only (or None if not configured).""" return self._owner_id_column_name + def _get_store_config_from_extension(self) -> "dict[str, Any]": + """Extract ADK store configuration from config.extension_config. + + Returns: + Dict with ADK table names and optionally owner_id_column. + """ + return dict(_get_adk_session_store_config(self._config)) + def _calculate_expires_at(self, expires_in: "int | timedelta | None") -> "datetime | None": """Calculate expiration timestamp from expires_in. @@ -340,6 +346,74 @@ async def delete_idle_sessions(self, updated_before: datetime) -> int: """ raise NotImplementedError + @abstractmethod + async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Return app-scoped state for an application. + + Args: + app_name: Application name. + + Returns: + App-scoped state mapping if present, otherwise ``None``. + """ + raise NotImplementedError + + @abstractmethod + async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Return user-scoped state for an application user. + + Args: + app_name: Application name. + user_id: User identifier. + + Returns: + User-scoped state mapping if present, otherwise ``None``. + """ + raise NotImplementedError + + @abstractmethod + async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Insert or replace app-scoped state for an application. + + Args: + app_name: Application name. + state: App-scoped state mapping. + """ + raise NotImplementedError + + @abstractmethod + async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Insert or replace user-scoped state for an application user. + + Args: + app_name: Application name. + user_id: User identifier. + state: User-scoped state mapping. + """ + raise NotImplementedError + + @abstractmethod + async def get_metadata(self, key: str) -> "str | None": + """Return a value from the ADK internal metadata table. + + Args: + key: Metadata key. + + Returns: + Metadata value if present, otherwise ``None``. + """ + raise NotImplementedError + + @abstractmethod + async def set_metadata(self, key: str, value: str) -> None: + """Set a value in the ADK internal metadata table. + + Args: + key: Metadata key. + value: Metadata value. + """ + raise NotImplementedError + @abstractmethod async def create_tables(self) -> None: """Create the sessions and events tables if they don't exist.""" @@ -406,6 +480,69 @@ async def _get_create_events_table_sql(self) -> str: """ raise NotImplementedError + @abstractmethod + async def _get_create_app_states_table_sql(self) -> str: + """Get the CREATE TABLE SQL for the app-scoped state table. + + Returns: + SQL statement to create the app-scoped state table. + """ + raise NotImplementedError + + @abstractmethod + async def _get_create_user_states_table_sql(self) -> str: + """Get the CREATE TABLE SQL for the user-scoped state table. + + Returns: + SQL statement to create the user-scoped state table. + """ + raise NotImplementedError + + @abstractmethod + async def _get_create_metadata_table_sql(self) -> str: + """Get the CREATE TABLE SQL for the ADK internal metadata table. + + Returns: + SQL statement to create the ADK internal metadata table. + """ + raise NotImplementedError + + @abstractmethod + async def _get_seed_metadata_sql(self) -> str: + """Get the SQL statement that seeds the ADK schema-version metadata row. + + Returns: + SQL statement that records ``schema_version = 1``. + """ + raise NotImplementedError + + @abstractmethod + def _get_drop_app_states_table_sql(self) -> str: + """Get the DROP TABLE SQL statement for the app-scoped state table. + + Returns: + SQL statement to drop the app-scoped state table. + """ + raise NotImplementedError + + @abstractmethod + def _get_drop_user_states_table_sql(self) -> str: + """Get the DROP TABLE SQL statement for the user-scoped state table. + + Returns: + SQL statement to drop the user-scoped state table. + """ + raise NotImplementedError + + @abstractmethod + def _get_drop_metadata_table_sql(self) -> str: + """Get the DROP TABLE SQL statement for the ADK internal metadata table. + + Returns: + SQL statement to drop the ADK internal metadata table. + """ + raise NotImplementedError + @abstractmethod def _get_drop_tables_sql(self) -> "list[str]": """Get the DROP TABLE SQL statements for this database dialect. diff --git a/tests/integration/adapters/_adk_contract_helpers.py b/tests/integration/adapters/_adk_contract_helpers.py index 685a80e97..cab6cda8d 100644 --- a/tests/integration/adapters/_adk_contract_helpers.py +++ b/tests/integration/adapters/_adk_contract_helpers.py @@ -6,12 +6,14 @@ from uuid import uuid4 from sqlspec.extensions.adk import EventRecord, MemoryRecord, SessionRecord +from sqlspec.extensions.adk.service import SQLSpecSessionService __all__ = ( "assert_memory_store_contract", "assert_session_event_cleanup_contract", "assert_session_event_store_contract", "assert_session_get_session_renewal_contract", + "assert_session_scoped_state_contract", "assert_session_table_lifecycle_contract", ) @@ -47,6 +49,14 @@ async def delete_expired_events(self, before: datetime) -> int: ... async def delete_idle_sessions(self, updated_before: datetime) -> int: ... + async def get_app_state(self, app_name: str) -> dict[str, object] | None: ... + + async def get_user_state(self, app_name: str, user_id: str) -> dict[str, object] | None: ... + + async def upsert_app_state(self, app_name: str, state: dict[str, object]) -> None: ... + + async def upsert_user_state(self, app_name: str, user_id: str, state: dict[str, object]) -> None: ... + async def drop_tables(self) -> None: ... async def recreate_tables(self) -> None: ... @@ -239,6 +249,63 @@ async def assert_session_get_session_renewal_contract(store: SessionEventStore, assert renewed["state"] == {"renew": True} +async def assert_session_scoped_state_contract(store: SessionEventStore, *, marker: str) -> None: + """Assert service-level app:/user:/temp: semantics over a real store.""" + from google.adk.events.event import Event + from google.adk.events.event_actions import EventActions + + service = SQLSpecSessionService(store) # type: ignore[arg-type] + app_name = _contract_key(marker, "scoped-app") + user_id = _contract_key(marker, "scoped-user") + other_user_id = _contract_key(marker, "scoped-other-user") + + session_a = await service.create_session( + app_name=app_name, + user_id=user_id, + session_id=_contract_key(marker, "scoped-session-a"), + state={"session_seed": "a", "temp:create": "drop"}, + ) + session_b = await service.create_session( + app_name=app_name, user_id=user_id, session_id=_contract_key(marker, "scoped-session-b"), state={} + ) + other_user_session = await service.create_session( + app_name=app_name, + user_id=other_user_id, + session_id=_contract_key(marker, "scoped-session-other-user"), + state={}, + ) + session_a = await service.get_session(app_name=app_name, user_id=user_id, session_id=session_a.id) + assert session_a is not None + + event = Event( + invocation_id=_contract_key(marker, "scoped-invocation"), + author="user", + timestamp=datetime.now(timezone.utc).timestamp(), + actions=EventActions(state_delta={"app:counter": 1, "user:theme": "dark", "turn": 1, "temp:scratch": "drop"}), + ) + await service.append_event(session_a, event) + + raw_session = await store.get_session(session_a.id) + assert raw_session is not None + assert raw_session["state"] == {"session_seed": "a", "turn": 1} + assert await store.get_app_state(app_name) == {"app:counter": 1} + assert await store.get_user_state(app_name, user_id) == {"user:theme": "dark"} + + fetched_a = await service.get_session(app_name=app_name, user_id=user_id, session_id=session_a.id) + assert fetched_a is not None + assert fetched_a.state == {"session_seed": "a", "turn": 1, "app:counter": 1, "user:theme": "dark"} + + fetched_b = await service.get_session(app_name=app_name, user_id=user_id, session_id=session_b.id) + assert fetched_b is not None + assert fetched_b.state == {"app:counter": 1, "user:theme": "dark"} + + fetched_other_user = await service.get_session( + app_name=app_name, user_id=other_user_id, session_id=other_user_session.id + ) + assert fetched_other_user is not None + assert fetched_other_user.state == {"app:counter": 1} + + async def assert_session_table_lifecycle_contract(store: SessionEventStore, *, marker: str) -> None: """Assert ADK stores can drop and recreate their managed session tables.""" app_name = _contract_key(marker, "lifecycle-app") diff --git a/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py b/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py index d06c20425..74c3ee353 100644 --- a/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py +++ b/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py @@ -10,6 +10,7 @@ from tests.integration.adapters._adk_contract_helpers import ( assert_session_event_cleanup_contract, assert_session_get_session_renewal_contract, + assert_session_scoped_state_contract, assert_session_table_lifecycle_contract, ) @@ -106,6 +107,11 @@ async def test_session_get_session_renewal_contract(adbc_store: Any) -> None: await assert_session_get_session_renewal_contract(adbc_store, marker="adbc") +async def test_session_scoped_state_contract(adbc_store: Any) -> None: + """ADBC service reads merge app/user state from dedicated scoped tables.""" + await assert_session_scoped_state_contract(adbc_store, marker="adbc") + + async def test_session_table_lifecycle_contract(adbc_store: Any) -> None: """ADBC can drop and recreate its ADK session tables programmatically.""" await assert_session_table_lifecycle_contract(adbc_store, marker="adbc") diff --git a/tests/integration/adapters/aiomysql/extensions/adk/conftest.py b/tests/integration/adapters/aiomysql/extensions/adk/conftest.py index 2bdd4285f..dfa0ebfb1 100644 --- a/tests/integration/adapters/aiomysql/extensions/adk/conftest.py +++ b/tests/integration/adapters/aiomysql/extensions/adk/conftest.py @@ -35,7 +35,15 @@ async def aiomysql_adk_store(mysql_service: MySQLService) -> "AsyncGenerator[Aio "minsize": 1, "maxsize": 5, }, - extension_config={"adk": {"session_table": "test_sessions", "events_table": "test_events"}}, + extension_config={ + "adk": { + "session_table": "test_sessions", + "events_table": "test_events", + "app_state_table": "test_app_state", + "user_state_table": "test_user_state", + "metadata_table": "test_adk_metadata", + } + }, ) try: @@ -47,6 +55,9 @@ async def aiomysql_adk_store(mysql_service: MySQLService) -> "AsyncGenerator[Aio async with config.provide_connection() as conn, AiomysqlCursor(conn) as cursor: await cursor.execute("DROP TABLE IF EXISTS test_events") await cursor.execute("DROP TABLE IF EXISTS test_sessions") + await cursor.execute("DROP TABLE IF EXISTS test_user_state") + await cursor.execute("DROP TABLE IF EXISTS test_app_state") + await cursor.execute("DROP TABLE IF EXISTS test_adk_metadata") await conn.commit() finally: pool = config.connection_instance @@ -85,6 +96,9 @@ async def aiomysql_adk_store_with_fk(mysql_service: MySQLService) -> "AsyncGener "adk": { "session_table": "test_fk_sessions", "events_table": "test_fk_events", + "app_state_table": "test_fk_app_state", + "user_state_table": "test_fk_user_state", + "metadata_table": "test_fk_adk_metadata", "owner_id_column": "tenant_id BIGINT NOT NULL REFERENCES test_tenants(id) ON DELETE CASCADE", } }, @@ -109,6 +123,9 @@ async def aiomysql_adk_store_with_fk(mysql_service: MySQLService) -> "AsyncGener async with config.provide_connection() as conn, AiomysqlCursor(conn) as cursor: await cursor.execute("DROP TABLE IF EXISTS test_fk_events") await cursor.execute("DROP TABLE IF EXISTS test_fk_sessions") + await cursor.execute("DROP TABLE IF EXISTS test_fk_user_state") + await cursor.execute("DROP TABLE IF EXISTS test_fk_app_state") + await cursor.execute("DROP TABLE IF EXISTS test_fk_adk_metadata") await cursor.execute("DROP TABLE IF EXISTS test_tenants") await conn.commit() finally: diff --git a/tests/integration/adapters/aiomysql/extensions/adk/test_store.py b/tests/integration/adapters/aiomysql/extensions/adk/test_store.py index b13389bd3..a162ced06 100644 --- a/tests/integration/adapters/aiomysql/extensions/adk/test_store.py +++ b/tests/integration/adapters/aiomysql/extensions/adk/test_store.py @@ -8,6 +8,7 @@ from sqlspec.adapters.aiomysql._typing import AiomysqlCursor from sqlspec.adapters.aiomysql.adk import AiomysqlADKStore from sqlspec.extensions.adk import EventRecord +from tests.integration.adapters._adk_contract_helpers import assert_session_scoped_state_contract pytestmark = [pytest.mark.xdist_group("mysql"), pytest.mark.aiomysql, pytest.mark.integration] @@ -18,6 +19,11 @@ async def test_create_tables(aiomysql_adk_store: AiomysqlADKStore) -> None: assert aiomysql_adk_store.events_table == "test_events" +async def test_aiomysql_session_scoped_state_contract(aiomysql_adk_store: AiomysqlADKStore) -> None: + """Aiomysql service reads merge app/user state from dedicated scoped tables.""" + await assert_session_scoped_state_contract(aiomysql_adk_store, marker="aiomysql") + + async def test_storage_types_verification(aiomysql_adk_store: AiomysqlADKStore) -> None: """Verify MySQL uses JSON type (not TEXT) and TIMESTAMP(6) for microseconds. diff --git a/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py b/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py index 7a0179103..81d1e9a41 100644 --- a/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py +++ b/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py @@ -12,6 +12,7 @@ assert_session_event_cleanup_contract, assert_session_event_store_contract, assert_session_get_session_renewal_contract, + assert_session_scoped_state_contract, assert_session_table_lifecycle_contract, ) @@ -38,7 +39,7 @@ async def test_aiosqlite_session_owner_column_is_created_when_configured(tmp_pat await store.create_session("session-owner", "app", "user", {}, owner_id="tenant-1") async with config.provide_connection() as conn: - cursor = await conn.execute("SELECT owner_id FROM adk_sessions WHERE id = ?", ("session-owner",)) + cursor = await conn.execute("SELECT owner_id FROM adk_session WHERE id = ?", ("session-owner",)) row = await cursor.fetchone() assert row == ("tenant-1",) @@ -87,6 +88,15 @@ async def test_aiosqlite_session_get_session_renewal_contract(tmp_path: Path) -> await config.close_pool() +async def test_aiosqlite_session_scoped_state_contract(tmp_path: Path) -> None: + """AioSQLite service reads merge app/user state from dedicated scoped tables.""" + config, store = await _build_store(tmp_path) + try: + await assert_session_scoped_state_contract(store, marker="aiosqlite") + finally: + await config.close_pool() + + async def test_aiosqlite_session_table_lifecycle_contract(tmp_path: Path) -> None: """AioSQLite can drop and recreate its ADK session tables programmatically.""" config, store = await _build_store(tmp_path) diff --git a/tests/integration/adapters/asyncmy/extensions/adk/conftest.py b/tests/integration/adapters/asyncmy/extensions/adk/conftest.py index 93032dc2b..99982a451 100644 --- a/tests/integration/adapters/asyncmy/extensions/adk/conftest.py +++ b/tests/integration/adapters/asyncmy/extensions/adk/conftest.py @@ -34,7 +34,15 @@ async def asyncmy_adk_store(mysql_service: MySQLService) -> "AsyncGenerator[Asyn "minsize": 1, "maxsize": 5, }, - extension_config={"adk": {"session_table": "test_sessions", "events_table": "test_events"}}, + extension_config={ + "adk": { + "session_table": "test_sessions", + "events_table": "test_events", + "app_state_table": "test_app_state", + "user_state_table": "test_user_state", + "metadata_table": "test_adk_metadata", + } + }, ) try: @@ -46,6 +54,9 @@ async def asyncmy_adk_store(mysql_service: MySQLService) -> "AsyncGenerator[Asyn async with config.provide_connection() as conn, conn.cursor() as cursor: await cursor.execute("DROP TABLE IF EXISTS test_events") await cursor.execute("DROP TABLE IF EXISTS test_sessions") + await cursor.execute("DROP TABLE IF EXISTS test_user_state") + await cursor.execute("DROP TABLE IF EXISTS test_app_state") + await cursor.execute("DROP TABLE IF EXISTS test_adk_metadata") await conn.commit() finally: pool = config.connection_instance @@ -84,6 +95,9 @@ async def asyncmy_adk_store_with_fk(mysql_service: MySQLService) -> "AsyncGenera "adk": { "session_table": "test_fk_sessions", "events_table": "test_fk_events", + "app_state_table": "test_fk_app_state", + "user_state_table": "test_fk_user_state", + "metadata_table": "test_fk_adk_metadata", "owner_id_column": "tenant_id BIGINT NOT NULL REFERENCES test_tenants(id) ON DELETE CASCADE", } }, @@ -108,6 +122,9 @@ async def asyncmy_adk_store_with_fk(mysql_service: MySQLService) -> "AsyncGenera async with config.provide_connection() as conn, conn.cursor() as cursor: await cursor.execute("DROP TABLE IF EXISTS test_fk_events") await cursor.execute("DROP TABLE IF EXISTS test_fk_sessions") + await cursor.execute("DROP TABLE IF EXISTS test_fk_user_state") + await cursor.execute("DROP TABLE IF EXISTS test_fk_app_state") + await cursor.execute("DROP TABLE IF EXISTS test_fk_adk_metadata") await cursor.execute("DROP TABLE IF EXISTS test_tenants") await conn.commit() finally: diff --git a/tests/integration/adapters/asyncmy/extensions/adk/test_store.py b/tests/integration/adapters/asyncmy/extensions/adk/test_store.py index 1a8f5a2b3..92ce3400c 100644 --- a/tests/integration/adapters/asyncmy/extensions/adk/test_store.py +++ b/tests/integration/adapters/asyncmy/extensions/adk/test_store.py @@ -7,6 +7,7 @@ from sqlspec.adapters.asyncmy.adk import AsyncmyADKStore from sqlspec.extensions.adk import EventRecord +from tests.integration.adapters._adk_contract_helpers import assert_session_scoped_state_contract pytestmark = [pytest.mark.xdist_group("mysql"), pytest.mark.asyncmy, pytest.mark.integration] @@ -17,6 +18,11 @@ async def test_create_tables(asyncmy_adk_store: AsyncmyADKStore) -> None: assert asyncmy_adk_store.events_table == "test_events" +async def test_asyncmy_session_scoped_state_contract(asyncmy_adk_store: AsyncmyADKStore) -> None: + """Asyncmy service reads merge app/user state from dedicated scoped tables.""" + await assert_session_scoped_state_contract(asyncmy_adk_store, marker="asyncmy") + + async def test_storage_types_verification(asyncmy_adk_store: AsyncmyADKStore) -> None: """Verify MySQL uses JSON type (not TEXT) and TIMESTAMP(6) for microseconds. diff --git a/tests/integration/adapters/asyncpg/extensions/adk/conftest.py b/tests/integration/adapters/asyncpg/extensions/adk/conftest.py index 25dd4e0c1..45e88839f 100644 --- a/tests/integration/adapters/asyncpg/extensions/adk/conftest.py +++ b/tests/integration/adapters/asyncpg/extensions/adk/conftest.py @@ -43,8 +43,11 @@ async def asyncpg_adk_store(postgres_service: "PostgresService") -> "AsyncGenera yield store async with config.provide_connection() as conn: - await conn.execute("DROP TABLE IF EXISTS adk_events CASCADE") - await conn.execute("DROP TABLE IF EXISTS adk_sessions CASCADE") + await conn.execute("DROP TABLE IF EXISTS adk_event CASCADE") + await conn.execute("DROP TABLE IF EXISTS adk_session CASCADE") + await conn.execute("DROP TABLE IF EXISTS adk_user_state CASCADE") + await conn.execute("DROP TABLE IF EXISTS adk_app_state CASCADE") + await conn.execute("DROP TABLE IF EXISTS adk_internal_metadata CASCADE") finally: if config.connection_instance: await config.close_pool() diff --git a/tests/integration/adapters/asyncpg/extensions/adk/test_owner_id_column.py b/tests/integration/adapters/asyncpg/extensions/adk/test_owner_id_column.py index 1e4606c9c..1cc4c3eb7 100644 --- a/tests/integration/adapters/asyncpg/extensions/adk/test_owner_id_column.py +++ b/tests/integration/adapters/asyncpg/extensions/adk/test_owner_id_column.py @@ -16,8 +16,8 @@ def _make_config_with_owner_id( postgres_service: Any, owner_id_column: "str | None" = None, - session_table: str = "adk_sessions", - events_table: str = "adk_events", + session_table: str = "adk_session", + events_table: str = "adk_event", ) -> AsyncpgConfig: """Helper to create config with ADK extension config.""" extension_config = cast("ExtensionConfigs", {"adk": {"session_table": session_table, "events_table": events_table}}) @@ -68,8 +68,8 @@ async def tenants_table(asyncpg_config_for_fk: AsyncpgConfig) -> "AsyncGenerator yield async with asyncpg_config_for_fk.provide_connection() as conn: - await conn.execute("DROP TABLE IF EXISTS adk_events CASCADE") - await conn.execute("DROP TABLE IF EXISTS adk_sessions CASCADE") + await conn.execute("DROP TABLE IF EXISTS adk_event CASCADE") + await conn.execute("DROP TABLE IF EXISTS adk_session CASCADE") await conn.execute("DROP TABLE IF EXISTS tenants CASCADE") @@ -93,8 +93,8 @@ async def users_table(asyncpg_config_for_fk: AsyncpgConfig) -> "AsyncGenerator[N yield async with asyncpg_config_for_fk.provide_connection() as conn: - await conn.execute("DROP TABLE IF EXISTS adk_events CASCADE") - await conn.execute("DROP TABLE IF EXISTS adk_sessions CASCADE") + await conn.execute("DROP TABLE IF EXISTS adk_event CASCADE") + await conn.execute("DROP TABLE IF EXISTS adk_session CASCADE") await conn.execute("DROP TABLE IF EXISTS users CASCADE") @@ -111,8 +111,8 @@ async def test_store_without_owner_id_column(asyncpg_config_for_fk: AsyncpgConfi assert session["state"] == {"data": "test"} async with asyncpg_config_for_fk.provide_connection() as conn: - await conn.execute("DROP TABLE IF EXISTS adk_events CASCADE") - await conn.execute("DROP TABLE IF EXISTS adk_sessions CASCADE") + await conn.execute("DROP TABLE IF EXISTS adk_event CASCADE") + await conn.execute("DROP TABLE IF EXISTS adk_session CASCADE") async def test_create_tables_with_owner_id_column( @@ -129,7 +129,7 @@ async def test_create_tables_with_owner_id_column( result = await conn.fetchrow(""" SELECT column_name, data_type, is_nullable FROM information_schema.columns - WHERE table_name = 'adk_sessions' AND column_name = 'tenant_id' + WHERE table_name = 'adk_session' AND column_name = 'tenant_id' """) assert result is not None @@ -158,7 +158,7 @@ async def test_create_session_with_owner_id(tenants_table: Any, postgres_service assert session["state"] == {"data": "test"} async with config.provide_connection() as conn: - result = await conn.fetchrow("SELECT tenant_id FROM adk_sessions WHERE id = $1", "session-1") + result = await conn.fetchrow("SELECT tenant_id FROM adk_session WHERE id = $1", "session-1") assert result is not None assert result["tenant_id"] == 1 finally: @@ -242,7 +242,7 @@ async def test_nullable_owner_id_column(tenants_table: Any, postgres_service: An assert session is not None async with config.provide_connection() as conn: - result = await conn.fetchrow("SELECT tenant_id FROM adk_sessions WHERE id = $1", "session-1") + result = await conn.fetchrow("SELECT tenant_id FROM adk_session WHERE id = $1", "session-1") assert result is not None assert result["tenant_id"] is None finally: @@ -262,13 +262,13 @@ async def test_set_null_on_delete_behavior(tenants_table: Any, postgres_service: await store.create_session("session-1", "app-1", "user-1", {"data": "test"}, owner_id=1) async with config.provide_connection() as conn: - result = await conn.fetchrow("SELECT tenant_id FROM adk_sessions WHERE id = $1", "session-1") + result = await conn.fetchrow("SELECT tenant_id FROM adk_session WHERE id = $1", "session-1") assert result is not None assert result["tenant_id"] == 1 await conn.execute("DELETE FROM tenants WHERE id = 1") - result = await conn.fetchrow("SELECT tenant_id FROM adk_sessions WHERE id = $1", "session-1") + result = await conn.fetchrow("SELECT tenant_id FROM adk_session WHERE id = $1", "session-1") assert result is not None assert result["tenant_id"] is None finally: @@ -294,7 +294,7 @@ async def test_uuid_owner_id_column(users_table: Any, postgres_service: Any) -> assert session is not None async with config.provide_connection() as conn: - result = await conn.fetchrow("SELECT account_id FROM adk_sessions WHERE id = $1", "session-1") + result = await conn.fetchrow("SELECT account_id FROM adk_session WHERE id = $1", "session-1") assert result is not None assert result["account_id"] == user_uuid finally: @@ -336,8 +336,8 @@ async def test_backwards_compatibility_without_owner_id(asyncpg_config_for_fk: A assert sessions[0]["id"] == "session-1" async with asyncpg_config_for_fk.provide_connection() as conn: - await conn.execute("DROP TABLE IF EXISTS adk_events CASCADE") - await conn.execute("DROP TABLE IF EXISTS adk_sessions CASCADE") + await conn.execute("DROP TABLE IF EXISTS adk_event CASCADE") + await conn.execute("DROP TABLE IF EXISTS adk_session CASCADE") async def test_owner_id_column_name_property(tenants_table: Any, postgres_service: Any) -> None: @@ -375,7 +375,7 @@ async def test_multiple_sessions_same_tenant(tenants_table: Any, postgres_servic await store.create_session(f"session-{i}", "app-1", f"user-{i}", {"session_num": i}, owner_id=1) async with config.provide_connection() as conn: - result = await conn.fetch("SELECT id FROM adk_sessions WHERE tenant_id = $1 ORDER BY id", 1) + result = await conn.fetch("SELECT id FROM adk_session WHERE tenant_id = $1 ORDER BY id", 1) assert len(result) == 5 assert [r["id"] for r in result] == [f"session-{i}" for i in range(5)] finally: diff --git a/tests/integration/adapters/asyncpg/extensions/adk/test_session_operations.py b/tests/integration/adapters/asyncpg/extensions/adk/test_session_operations.py index cc6ac7fad..435765823 100644 --- a/tests/integration/adapters/asyncpg/extensions/adk/test_session_operations.py +++ b/tests/integration/adapters/asyncpg/extensions/adk/test_session_operations.py @@ -4,6 +4,8 @@ import pytest +from tests.integration.adapters._adk_contract_helpers import assert_session_scoped_state_contract + pytestmark = [pytest.mark.xdist_group("postgres"), pytest.mark.asyncpg, pytest.mark.integration] @@ -40,6 +42,11 @@ async def test_get_session(asyncpg_adk_store: Any) -> None: assert retrieved["state"] == state +async def test_asyncpg_session_scoped_state_contract(asyncpg_adk_store: Any) -> None: + """Asyncpg service reads merge app/user state from dedicated scoped tables.""" + await assert_session_scoped_state_contract(asyncpg_adk_store, marker="asyncpg") + + async def test_get_nonexistent_session(asyncpg_adk_store: Any) -> None: """Test retrieving a session that doesn't exist.""" result = await asyncpg_adk_store.get_session("nonexistent") diff --git a/tests/integration/adapters/duckdb/extensions/adk/test_store.py b/tests/integration/adapters/duckdb/extensions/adk/test_store.py index 73088b2a1..ae2e7dee3 100644 --- a/tests/integration/adapters/duckdb/extensions/adk/test_store.py +++ b/tests/integration/adapters/duckdb/extensions/adk/test_store.py @@ -14,6 +14,7 @@ assert_session_event_cleanup_contract, assert_session_event_store_contract, assert_session_get_session_renewal_contract, + assert_session_scoped_state_contract, assert_session_table_lifecycle_contract, ) @@ -68,6 +69,11 @@ async def test_duckdb_session_get_session_renewal_contract(duckdb_adk_store: Duc await assert_session_get_session_renewal_contract(duckdb_adk_store, marker="duckdb") +async def test_duckdb_session_scoped_state_contract(duckdb_adk_store: DuckdbADKStore) -> None: + """DuckDB service reads merge app/user state from dedicated scoped tables.""" + await assert_session_scoped_state_contract(duckdb_adk_store, marker="duckdb") + + async def test_duckdb_session_table_lifecycle_contract(duckdb_adk_store: DuckdbADKStore) -> None: """DuckDB can drop and recreate its ADK session tables programmatically.""" await assert_session_table_lifecycle_contract(duckdb_adk_store, marker="duckdb") diff --git a/tests/integration/adapters/mysqlconnector/extensions/adk/conftest.py b/tests/integration/adapters/mysqlconnector/extensions/adk/conftest.py index 561ef9683..6b4a41ae1 100644 --- a/tests/integration/adapters/mysqlconnector/extensions/adk/conftest.py +++ b/tests/integration/adapters/mysqlconnector/extensions/adk/conftest.py @@ -22,7 +22,15 @@ async def mysqlconnector_adk_store(mysql_service: MySQLService) -> "AsyncGenerat "autocommit": False, "use_pure": True, }, - extension_config={"adk": {"session_table": "test_sessions", "events_table": "test_events"}}, + extension_config={ + "adk": { + "session_table": "test_sessions", + "events_table": "test_events", + "app_state_table": "test_app_state", + "user_state_table": "test_user_state", + "metadata_table": "test_adk_metadata", + } + }, ) try: @@ -36,6 +44,9 @@ async def mysqlconnector_adk_store(mysql_service: MySQLService) -> "AsyncGenerat try: await cursor.execute("DROP TABLE IF EXISTS test_events") await cursor.execute("DROP TABLE IF EXISTS test_sessions") + await cursor.execute("DROP TABLE IF EXISTS test_user_state") + await cursor.execute("DROP TABLE IF EXISTS test_app_state") + await cursor.execute("DROP TABLE IF EXISTS test_adk_metadata") await conn.commit() finally: await cursor.close() @@ -62,6 +73,9 @@ async def mysqlconnector_adk_store_with_fk( "adk": { "session_table": "test_fk_sessions", "events_table": "test_fk_events", + "app_state_table": "test_fk_app_state", + "user_state_table": "test_fk_user_state", + "metadata_table": "test_fk_adk_metadata", "owner_id_column": "tenant_id BIGINT NOT NULL REFERENCES test_tenants(id) ON DELETE CASCADE", } }, @@ -92,6 +106,9 @@ async def mysqlconnector_adk_store_with_fk( try: await cursor.execute("DROP TABLE IF EXISTS test_fk_events") await cursor.execute("DROP TABLE IF EXISTS test_fk_sessions") + await cursor.execute("DROP TABLE IF EXISTS test_fk_user_state") + await cursor.execute("DROP TABLE IF EXISTS test_fk_app_state") + await cursor.execute("DROP TABLE IF EXISTS test_fk_adk_metadata") await cursor.execute("DROP TABLE IF EXISTS test_tenants") await conn.commit() finally: diff --git a/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py b/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py index ffe142eef..ba288bd1f 100644 --- a/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py +++ b/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py @@ -8,6 +8,7 @@ from sqlspec.adapters.mysqlconnector.adk import MysqlConnectorAsyncADKStore from sqlspec.extensions.adk import EventRecord +from tests.integration.adapters._adk_contract_helpers import assert_session_scoped_state_contract pytestmark = [pytest.mark.xdist_group("mysql"), pytest.mark.mysql_connector, pytest.mark.integration] @@ -18,6 +19,13 @@ async def test_create_tables(mysqlconnector_adk_store: MysqlConnectorAsyncADKSto assert mysqlconnector_adk_store.events_table == "test_events" +async def test_mysqlconnector_session_scoped_state_contract( + mysqlconnector_adk_store: MysqlConnectorAsyncADKStore, +) -> None: + """MysqlConnector service reads merge app/user state from dedicated scoped tables.""" + await assert_session_scoped_state_contract(mysqlconnector_adk_store, marker="mysqlconnector") + + async def test_storage_types_verification(mysqlconnector_adk_store: MysqlConnectorAsyncADKStore) -> None: """Verify MySQL uses JSON type (not TEXT) and TIMESTAMP(6) for microseconds.""" config = mysqlconnector_adk_store.config diff --git a/tests/integration/adapters/oracledb/extensions/adk/test_inmemory.py b/tests/integration/adapters/oracledb/extensions/adk/test_inmemory.py index 5a32a911b..59f4a3d08 100644 --- a/tests/integration/adapters/oracledb/extensions/adk/test_inmemory.py +++ b/tests/integration/adapters/oracledb/extensions/adk/test_inmemory.py @@ -35,7 +35,7 @@ async def test_inmemory_enabled_creates_sessions_table_with_inmemory_async( """ SELECT inmemory, inmemory_priority, inmemory_distribute FROM user_tables - WHERE table_name = 'ADK_SESSIONS' + WHERE table_name = 'ADK_SESSION' """ ) row = await cursor.fetchone() @@ -77,7 +77,7 @@ async def test_inmemory_enabled_creates_events_table_with_inmemory_async( """ SELECT inmemory, inmemory_priority, inmemory_distribute FROM user_tables - WHERE table_name = 'ADK_EVENTS' + WHERE table_name = 'ADK_EVENT' """ ) row = await cursor.fetchone() @@ -115,7 +115,7 @@ async def test_inmemory_disabled_creates_tables_without_inmemory_async(oracle_as """ SELECT inmemory, inmemory_priority, inmemory_distribute FROM user_tables - WHERE table_name IN ('ADK_SESSIONS', 'ADK_EVENTS') + WHERE table_name IN ('ADK_SESSION', 'ADK_EVENT') ORDER BY table_name """ ) @@ -153,7 +153,7 @@ async def test_inmemory_default_disabled_async(oracle_async_config: OracleAsyncC """ SELECT inmemory FROM user_tables - WHERE table_name = 'ADK_SESSIONS' + WHERE table_name = 'ADK_SESSION' """ ) row = await cursor.fetchone() @@ -214,7 +214,7 @@ async def test_inmemory_with_owner_id_column_async(oracle_async_config: OracleAs SELECT inmemory, column_name FROM user_tables t LEFT JOIN user_tab_columns c ON t.table_name = c.table_name - WHERE t.table_name = 'ADK_SESSIONS' AND (c.column_name = 'OWNER_ID' OR c.column_name IS NULL) + WHERE t.table_name = 'ADK_SESSION' AND (c.column_name = 'OWNER_ID' OR c.column_name IS NULL) """ ) rows = await cursor.fetchall() @@ -318,7 +318,7 @@ async def test_inmemory_enabled_sync(oracle_sync_config: OracleSyncConfig) -> No """ SELECT inmemory, inmemory_priority FROM user_tables - WHERE table_name IN ('ADK_SESSIONS', 'ADK_EVENTS') + WHERE table_name IN ('ADK_SESSION', 'ADK_EVENT') ORDER BY table_name """ ) @@ -359,7 +359,7 @@ async def test_inmemory_disabled_sync(oracle_sync_config: OracleSyncConfig) -> N """ SELECT inmemory, inmemory_priority FROM user_tables - WHERE table_name IN ('ADK_SESSIONS', 'ADK_EVENTS') + WHERE table_name IN ('ADK_SESSION', 'ADK_EVENT') """ ) rows = cursor.fetchall() diff --git a/tests/integration/adapters/spanner/extensions/adk/conftest.py b/tests/integration/adapters/spanner/extensions/adk/conftest.py index 57ad9bace..2e2a983d3 100644 --- a/tests/integration/adapters/spanner/extensions/adk/conftest.py +++ b/tests/integration/adapters/spanner/extensions/adk/conftest.py @@ -25,7 +25,7 @@ def spanner_adk_config(spanner_service: SpannerService, spanner_database: "Datab "min_sessions": 1, "max_sessions": 5, }, - extension_config={"adk": {"session_table": "adk_sessions", "events_table": "adk_events"}}, + extension_config={"adk": {"session_table": "adk_session", "events_table": "adk_event"}}, ) diff --git a/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py b/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py index 5ff1ee911..29dcf2a53 100644 --- a/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py +++ b/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py @@ -7,6 +7,7 @@ import pytest from sqlspec.extensions.adk import EventRecord +from tests.integration.adapters._adk_contract_helpers import assert_session_scoped_state_contract pytestmark = [pytest.mark.spanner, pytest.mark.integration] @@ -22,6 +23,11 @@ async def test_create_and_get_session(spanner_adk_store: Any) -> None: assert fetched["state"] == {"a": 1} +async def test_spanner_session_scoped_state_contract(spanner_adk_store: Any) -> None: + """Spanner service reads merge app/user state from dedicated scoped tables.""" + await assert_session_scoped_state_contract(spanner_adk_store, marker="spanner") + + async def test_update_session_state(spanner_adk_store: Any) -> None: session_id = "session-update" await spanner_adk_store.delete_session(session_id) diff --git a/tests/integration/adapters/sqlite/extensions/adk/test_store.py b/tests/integration/adapters/sqlite/extensions/adk/test_store.py index 2444c59e0..9cc86efc2 100644 --- a/tests/integration/adapters/sqlite/extensions/adk/test_store.py +++ b/tests/integration/adapters/sqlite/extensions/adk/test_store.py @@ -12,6 +12,7 @@ assert_session_event_cleanup_contract, assert_session_event_store_contract, assert_session_get_session_renewal_contract, + assert_session_scoped_state_contract, assert_session_table_lifecycle_contract, ) @@ -67,6 +68,15 @@ async def test_sqlite_session_get_session_renewal_contract(tmp_path: Path) -> No config.close_pool() +async def test_sqlite_session_scoped_state_contract(tmp_path: Path) -> None: + """SQLite service reads merge app/user state from dedicated scoped tables.""" + config, store = await _build_store(tmp_path) + try: + await assert_session_scoped_state_contract(store, marker="sqlite") + finally: + config.close_pool() + + async def test_sqlite_session_table_lifecycle_contract(tmp_path: Path) -> None: """SQLite can drop and recreate its ADK session tables programmatically.""" config, store = await _build_store(tmp_path) diff --git a/tests/unit/extensions/test_adk/test_service.py b/tests/unit/extensions/test_adk/test_service.py index fb3ada2bb..81c9d8776 100644 --- a/tests/unit/extensions/test_adk/test_service.py +++ b/tests/unit/extensions/test_adk/test_service.py @@ -45,6 +45,10 @@ def __init__(self) -> None: # Track calls to create_session self.create_session_calls: list[dict[str, Any]] = [] + self.upsert_app_state_calls: list[dict[str, Any]] = [] + self.upsert_user_state_calls: list[dict[str, Any]] = [] + self.app_state: dict[str, Any] | None = None + self.user_state: dict[str, Any] | None = None # Provide a get_session that returns a minimal session record self._session_record = { @@ -104,6 +108,20 @@ async def append_event(self, event_record: Any) -> None: async def get_events(self, *, session_id: str, after_timestamp: Any = None, limit: Any = None) -> list: return [] + async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + return self.app_state + + async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + return self.user_state + + async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + self.upsert_app_state_calls.append({"app_name": app_name, "state": state}) + self.app_state = state + + async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + self.upsert_user_state_calls.append({"app_name": app_name, "user_id": user_id, "state": state}) + self.user_state = state + async def list_sessions(self, *, app_name: str, user_id: "str | None" = None) -> list: return [] @@ -154,6 +172,21 @@ async def test_get_session_forwards_renew_for_to_store() -> None: assert store.get_session_call_args[0] == {"session_id": "s1", "renew_for": renew_for} +@pytest.mark.anyio +async def test_get_session_merges_scoped_state_from_store() -> None: + """get_session returns the ADK merged state view.""" + store = MockStore() + store._session_record["state"] = {"session:key": "session"} + store.app_state = {"app:counter": 1} + store.user_state = {"user:theme": "dark"} + service = SQLSpecSessionService(store) # type: ignore[arg-type] + + session = await service.get_session(app_name="app", user_id="u1", session_id="s1") + + assert session is not None + assert session.state == {"session:key": "session", "app:counter": 1, "user:theme": "dark"} + + # --------------------------------------------------------------------------- # append_event — calls append_event_and_update_state # --------------------------------------------------------------------------- @@ -247,6 +280,22 @@ async def test_append_event_strips_temp_state_delta_from_persisted_state() -> No assert persisted_state["regular"] == "updated" +@pytest.mark.anyio +async def test_append_event_routes_scoped_state_to_app_and_user_tables() -> None: + """app:* and user:* state is not stored in the per-session state blob.""" + store = MockStore() + service = SQLSpecSessionService(store) # type: ignore[arg-type] + session = _make_session(state={"regular": "v0", "app:counter": 1, "user:theme": "light"}) + event = _make_event(state_delta={"regular": "v1", "app:counter": 2, "user:theme": "dark"}) + + await service.append_event(session, event) + + persisted_state = store.append_event_and_update_state_calls[-1]["state"] + assert persisted_state == {"regular": "v1"} + assert store.upsert_app_state_calls[-1] == {"app_name": "app", "state": {"app:counter": 2}} + assert store.upsert_user_state_calls[-1] == {"app_name": "app", "user_id": "u1", "state": {"user:theme": "dark"}} + + @pytest.mark.anyio async def test_append_event_skips_partial_events() -> None: """Partial events are not persisted to the store.""" @@ -323,7 +372,21 @@ async def test_create_session_strips_temp_keys_from_initial_state() -> None: persisted_state = store.create_session_calls[0]["state"] assert "temp:y" not in persisted_state assert persisted_state["x"] == 1 - assert persisted_state["app:z"] == 3 + assert "app:z" not in persisted_state + assert store.upsert_app_state_calls[-1] == {"app_name": "app", "state": {"app:z": 3}} + + +@pytest.mark.anyio +async def test_create_session_routes_initial_user_scoped_state() -> None: + """create_session writes user:* state to the user-scoped store.""" + store = MockStore() + service = SQLSpecSessionService(store) # type: ignore[arg-type] + + await service.create_session(app_name="app", user_id="u1", state={"x": 1, "user:theme": "dark"}) + + persisted_state = store.create_session_calls[0]["state"] + assert persisted_state == {"x": 1} + assert store.upsert_user_state_calls[-1] == {"app_name": "app", "user_id": "u1", "state": {"user:theme": "dark"}} @pytest.mark.anyio diff --git a/tests/unit/extensions/test_adk/test_store_config.py b/tests/unit/extensions/test_adk/test_store_config.py index ea047fb6b..f46c72542 100644 --- a/tests/unit/extensions/test_adk/test_store_config.py +++ b/tests/unit/extensions/test_adk/test_store_config.py @@ -83,6 +83,24 @@ async def delete_expired_events(self, before: datetime) -> int: async def delete_idle_sessions(self, updated_before: datetime) -> int: return 0 + async def get_app_state(self, app_name: str) -> dict[str, Any] | None: + return None + + async def get_user_state(self, app_name: str, user_id: str) -> dict[str, Any] | None: + return None + + async def upsert_app_state(self, app_name: str, state: dict[str, Any]) -> None: + return None + + async def upsert_user_state(self, app_name: str, user_id: str, state: dict[str, Any]) -> None: + return None + + async def get_metadata(self, key: str) -> str | None: + return None + + async def set_metadata(self, key: str, value: str) -> None: + return None + async def create_tables(self) -> None: return None @@ -92,10 +110,51 @@ async def _get_create_sessions_table_sql(self) -> str: async def _get_create_events_table_sql(self) -> str: return "" + async def _get_create_app_states_table_sql(self) -> str: + return "" + + async def _get_create_user_states_table_sql(self) -> str: + return "" + + async def _get_create_metadata_table_sql(self) -> str: + return "" + + async def _get_seed_metadata_sql(self) -> str: + return "" + + def _get_drop_app_states_table_sql(self) -> str: + return "" + + def _get_drop_user_states_table_sql(self) -> str: + return "" + + def _get_drop_metadata_table_sql(self) -> str: + return "" + def _get_drop_tables_sql(self) -> list[str]: return [] +class _MigrationSessionStore(_AsyncSessionStore): + async def _get_create_sessions_table_sql(self) -> str: + return "create sessions" + + async def _get_create_events_table_sql(self) -> str: + return "create events" + + async def _get_create_app_states_table_sql(self) -> str: + return "create app states" + + async def _get_create_user_states_table_sql(self) -> str: + return "create user states" + + async def _get_create_metadata_table_sql(self) -> str: + return "create metadata" + + async def _get_seed_metadata_sql(self) -> str: + return "seed metadata" + + class _AsyncMemoryStore(BaseAsyncADKMemoryStore[Any]): def __init__(self, config: _Config) -> None: super().__init__(config) @@ -181,6 +240,28 @@ def test_session_store_contract_declares_cleanup_hooks() -> None: assert "delete_idle_sessions" in BaseAsyncADKStore.__abstractmethods__ +def test_session_store_contract_declares_schema_parity_hooks() -> None: + expected_methods = { + "_get_create_app_states_table_sql", + "_get_create_user_states_table_sql", + "_get_create_metadata_table_sql", + "_get_drop_app_states_table_sql", + "_get_drop_user_states_table_sql", + "_get_drop_metadata_table_sql", + "_get_seed_metadata_sql", + "get_app_state", + "get_user_state", + "upsert_app_state", + "upsert_user_state", + "get_metadata", + "set_metadata", + } + + assert expected_methods <= BaseAsyncADKStore.__abstractmethods__ + for method_name in expected_methods: + assert inspect.getdoc(getattr(BaseAsyncADKStore, method_name)) + + def test_session_store_resolves_schema_parity_table_names() -> None: store = _AsyncSessionStore( _Config({ @@ -197,6 +278,36 @@ def test_session_store_resolves_schema_parity_table_names() -> None: assert store.metadata_table == "agent_metadata" +def test_session_store_uses_singular_default_table_names() -> None: + store = _AsyncSessionStore(_Config()) + + assert store.session_table == "adk_session" + assert store.events_table == "adk_event" + assert store.app_state_table == "adk_app_state" + assert store.user_state_table == "adk_user_state" + assert store.metadata_table == "adk_internal_metadata" + + +@pytest.mark.anyio +async def test_adk_migration_up_includes_schema_parity_tables(monkeypatch: pytest.MonkeyPatch) -> None: + migration = __import__("sqlspec.extensions.adk.migrations.0001_create_adk_tables", fromlist=["up"]) + context = type("MigrationContext", (), {"config": _Config()})() + + monkeypatch.setattr(migration, "_get_store_class", lambda _context: _MigrationSessionStore) + monkeypatch.setattr(migration, "_is_memory_enabled", lambda _context: False) + + statements = await migration.up(context) + + assert statements == [ + "create sessions", + "create events", + "create app states", + "create user states", + "create metadata", + "seed metadata", + ] + + @pytest.mark.parametrize("field", ["app_state_table", "user_state_table", "metadata_table"]) def test_session_store_validates_schema_parity_table_names(field: str) -> None: with pytest.raises(ValueError, match="Invalid table name"): diff --git a/tests/unit/extensions/test_adk/test_versioning.py b/tests/unit/extensions/test_adk/test_versioning.py index 5c0019ad0..1b43ef70a 100644 --- a/tests/unit/extensions/test_adk/test_versioning.py +++ b/tests/unit/extensions/test_adk/test_versioning.py @@ -38,6 +38,10 @@ def test_default_version_plan_matches_clean_break_v1_contract() -> None: ) +def test_schema_version_key_matches_official_adk_metadata_key() -> None: + assert ADK_SCHEMA_VERSION_KEY == "schema_version" + + def test_version_plan_metadata_items_include_schema_and_payload_versions() -> None: metadata_items = dict(_get_adk_version_plan(_Config({})).metadata_items()) diff --git a/tests/unit/extensions/test_events/test_events_config.py b/tests/unit/extensions/test_events/test_events_config.py index f304456b3..0c602d833 100644 --- a/tests/unit/extensions/test_events/test_events_config.py +++ b/tests/unit/extensions/test_events/test_events_config.py @@ -90,7 +90,7 @@ def test_adk_extension_auto_includes_migrations(tmp_path) -> None: config = SqliteConfig( connection_config={"database": str(tmp_path / "adk.db")}, migration_config={"script_location": "migrations"}, - extension_config={"adk": {"session_table": "adk_sessions"}}, + extension_config={"adk": {"session_table": "adk_session"}}, ) include_extensions = config.migration_config.get("include_extensions") diff --git a/tests/unit/utils/test_identifiers.py b/tests/unit/utils/test_identifiers.py index 30f962ae8..88063679b 100644 --- a/tests/unit/utils/test_identifiers.py +++ b/tests/unit/utils/test_identifiers.py @@ -7,7 +7,7 @@ def test_validate_identifier_returns_valid_name_unchanged() -> None: """Valid identifiers are returned unchanged.""" - assert validate_identifier("adk_sessions") == "adk_sessions" + assert validate_identifier("adk_session") == "adk_session" @pytest.mark.parametrize("name", ["", "1_table", "table-name", "table name", "foo; DROP TABLE x"]) @@ -28,22 +28,22 @@ def test_validate_identifier_rejects_names_longer_than_default_limit() -> None: def test_validate_identifier_rejects_schema_qualifier_by_default() -> None: """Schema-qualified names are rejected unless explicitly enabled.""" with pytest.raises(ValueError, match="Schema qualifier not allowed"): - validate_identifier("public.adk_sessions") + validate_identifier("public.adk_session") def test_validate_identifier_accepts_schema_qualified_name_when_enabled() -> None: """Schema-qualified names are validated segment by segment when enabled.""" - assert validate_identifier("public.adk_sessions", allow_schema_qualifier=True) == "public.adk_sessions" + assert validate_identifier("public.adk_session", allow_schema_qualifier=True) == "public.adk_session" def test_validate_identifier_accepts_multi_segment_qualified_name_when_enabled() -> None: """Existing event queue behavior accepts multi-segment qualified names.""" - name = "catalog.public.adk_events" + name = "catalog.public.adk_event" assert validate_identifier(name, allow_schema_qualifier=True) == name -@pytest.mark.parametrize("name", [".adk_sessions", "public.", "public..adk_sessions", "public.1_sessions"]) +@pytest.mark.parametrize("name", [".adk_session", "public.", "public..adk_session", "public.1_sessions"]) def test_validate_identifier_rejects_invalid_schema_qualified_segments(name: str) -> None: """Every schema-qualified segment must be a valid identifier.""" with pytest.raises(ValueError, match="Invalid identifier"): From 73b7dc4797807e9388d2c53d56106b113a936756 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 24 May 2026 15:18:11 +0000 Subject: [PATCH 17/29] feat(adk): atomic scoped-state writes in single transaction Extends BaseAsyncADKStore.append_event_and_update_state with app_name, user_id, app_state, user_state keyword arguments so events INSERT, sessions UPDATE, app_state UPSERT, and user_state UPSERT execute inside one adapter-local transaction. SQLSpecSessionService stops issuing separate upsert_app_state / upsert_user_state calls and routes everything through the unified atomic method, closing the silent correctness drift where a crash between calls left session state advanced but scoped state stale (Chapter 14 v8io acceptance). Shared contract helper assert_session_atomic_scoped_write_contract is exercised by every retained adapter test file that already used the scoped-state contract. --- sqlspec/adapters/adbc/adk/store.py | 75 +++++++-- sqlspec/adapters/aiomysql/adk/store.py | 39 ++++- sqlspec/adapters/aiosqlite/adk/store.py | 49 ++++-- sqlspec/adapters/asyncmy/adk/store.py | 39 ++++- sqlspec/adapters/asyncpg/adk/store.py | 34 +++- .../adapters/cockroach_asyncpg/adk/store.py | 34 +++- .../adapters/cockroach_psycopg/adk/store.py | 90 +++++++++- sqlspec/adapters/duckdb/adk/store.py | 75 +++++++-- sqlspec/adapters/mysqlconnector/adk/store.py | 100 +++++++++-- sqlspec/adapters/oracledb/adk/store.py | 159 ++++++++++++++---- sqlspec/adapters/psqlpy/adk/store.py | 34 +++- sqlspec/adapters/psycopg/adk/store.py | 94 ++++++++++- sqlspec/adapters/pymysql/adk/store.py | 61 +++++-- sqlspec/adapters/spanner/adk/store.py | 91 ++++++++-- sqlspec/adapters/sqlite/adk/store.py | 68 ++++++-- sqlspec/extensions/adk/service.py | 14 +- sqlspec/extensions/adk/store.py | 32 +++- .../adapters/_adk_contract_helpers.py | 70 +++++++- .../extensions/adk/test_session_operations.py | 6 + .../aiomysql/extensions/adk/test_store.py | 10 +- .../aiosqlite/extensions/adk/test_store.py | 10 ++ .../asyncmy/extensions/adk/test_store.py | 10 +- .../extensions/adk/test_session_operations.py | 10 +- .../duckdb/extensions/adk/test_store.py | 6 + .../extensions/adk/test_store.py | 12 +- .../spanner/extensions/adk/test_adk_store.py | 10 +- .../sqlite/extensions/adk/test_store.py | 10 ++ .../unit/extensions/test_adk/test_service.py | 30 +++- .../extensions/test_adk/test_store_config.py | 14 +- 29 files changed, 1095 insertions(+), 191 deletions(-) diff --git a/sqlspec/adapters/adbc/adk/store.py b/sqlspec/adapters/adbc/adk/store.py index c550aa08d..a35e2c18c 100644 --- a/sqlspec/adapters/adbc/adk/store.py +++ b/sqlspec/adapters/adbc/adk/store.py @@ -128,10 +128,26 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis return await async_(self._list_sessions)(app_name, user_id) async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + session_id: str, + state: "dict[str, Any]", + *, + app_name: "str | None" = None, + user_id: "str | None" = None, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: - """Atomically append an event and update the session's durable state.""" - return await async_(self._append_event_and_update_state)(event_record, session_id, state) + """Atomically append an event and update session + scoped state.""" + return await async_(self._append_event_and_update_state)( + event_record, + session_id, + state, + app_name=app_name, + user_id=user_id, + app_state=app_state, + user_state=user_state, + ) async def get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None @@ -884,21 +900,22 @@ def _insert_event(self, event_record: "EventRecord") -> None: cursor.close() def _append_event_and_update_state( - self, event_record: "EventRecord", session_id: str, state: "dict[str, Any]" + self, + event_record: "EventRecord", + session_id: str, + state: "dict[str, Any]", + *, + app_name: "str | None" = None, + user_id: "str | None" = None, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: - """Atomically insert an event and update the session's durable state. - - The event insert, state update, and refresh-SELECT are executed within - a single connection and committed together. ADBC drivers wrap a - variety of backends (Postgres, SQLite, DuckDB, ...) so we use a - SELECT-after-UPDATE rather than relying on RETURNING which not every - backend supports. + """Atomically insert an event and update session + scoped state. - Args: - event_record: Event record to store. - session_id: Session identifier whose state should be updated. - state: Post-append durable state snapshot (``temp:`` keys already - stripped by the service layer). + The event insert, state update, scoped-state upserts, and refresh-SELECT + are executed within a single connection and committed together. ADBC + drivers wrap a variety of backends (Postgres, SQLite, DuckDB, ...) so we + use DELETE+INSERT for upserts to remain portable across dialects. """ insert_sql = f""" INSERT INTO {self._events_table} ( @@ -915,8 +932,28 @@ def _append_event_and_update_state( FROM {self._session_table} WHERE id = ? """ + app_delete_sql = f"DELETE FROM {self._app_state_table} WHERE app_name = ?" + app_insert_sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES (?, ?, ?) + """ + user_delete_sql = f"DELETE FROM {self._user_state_table} WHERE app_name = ? AND user_id = ?" + user_insert_sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (?, ?, ?, ?) + """ + if app_state and app_name is None: + msg = "app_name is required when app_state is provided." + raise ValueError(msg) + if user_state and (app_name is None or user_id is None): + msg = "app_name and user_id are required when user_state is provided." + raise ValueError(msg) + state_json = self._serialize_state(state) event_data = self._serialize_json_field(event_record["event_data"]) + now = datetime.now(timezone.utc) + app_state_serialized = self._serialize_state(app_state) if app_state else None + user_state_serialized = self._serialize_state(user_state) if user_state else None with self._config.provide_connection() as conn: cursor = conn.cursor() @@ -934,6 +971,12 @@ def _append_event_and_update_state( cursor.execute(update_sql, (state_json, session_id)) cursor.execute(select_sql, (session_id,)) row = cursor.fetchone() + if app_state: + cursor.execute(app_delete_sql, (app_name,)) + cursor.execute(app_insert_sql, (app_name, app_state_serialized, now)) + if user_state: + cursor.execute(user_delete_sql, (app_name, user_id)) + cursor.execute(user_insert_sql, (app_name, user_id, user_state_serialized, now)) conn.commit() except Exception: with contextlib.suppress(Exception): diff --git a/sqlspec/adapters/aiomysql/adk/store.py b/sqlspec/adapters/aiomysql/adk/store.py index 8a81c5df2..2baa964ef 100644 --- a/sqlspec/adapters/aiomysql/adk/store.py +++ b/sqlspec/adapters/aiomysql/adk/store.py @@ -291,18 +291,21 @@ async def append_event(self, event_record: EventRecord) -> None: await conn.commit() async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + session_id: str, + state: "dict[str, Any]", + *, + app_name: "str | None" = None, + user_id: "str | None" = None, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: - """Atomically append an event and update the session's durable state. + """Atomically append an event and update session + scoped state. MySQL doesn't support UPDATE...RETURNING; we follow the UPDATE with a SELECT inside the same transaction so callers get the refreshed row in a single round-trip pair (no separate connection acquisition). - - Args: - event_record: Event record to store. - session_id: Session identifier whose state should be updated. - state: Post-append durable state snapshot. """ event_data = event_record["event_data"] event_data_str = to_json(event_data) if not isinstance(event_data, str) else event_data @@ -326,6 +329,18 @@ async def append_event_and_update_state( WHERE id = %s """ + app_upsert_sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES (%s, %s, UTC_TIMESTAMP(6)) + ON DUPLICATE KEY UPDATE state = VALUES(state), update_time = UTC_TIMESTAMP(6) + """ + + user_upsert_sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (%s, %s, %s, UTC_TIMESTAMP(6)) + ON DUPLICATE KEY UPDATE state = VALUES(state), update_time = UTC_TIMESTAMP(6) + """ + async with ( self._config.provide_connection() as conn, AiomysqlCursor(conn, cursor_class=AiomysqlRawCursor) as cursor, @@ -343,6 +358,16 @@ async def append_event_and_update_state( await cursor.execute(update_sql, (state_json, session_id)) await cursor.execute(select_sql, (session_id,)) row = await cursor.fetchone() + if app_state: + if app_name is None: + msg = "app_name is required when app_state is provided." + raise ValueError(msg) + await cursor.execute(app_upsert_sql, (app_name, to_json(app_state))) + if user_state: + if app_name is None or user_id is None: + msg = "app_name and user_id are required when user_state is provided." + raise ValueError(msg) + await cursor.execute(user_upsert_sql, (app_name, user_id, to_json(user_state))) await conn.commit() if row is None: diff --git a/sqlspec/adapters/aiosqlite/adk/store.py b/sqlspec/adapters/aiosqlite/adk/store.py index de182d9c1..82179e24d 100644 --- a/sqlspec/adapters/aiosqlite/adk/store.py +++ b/sqlspec/adapters/aiosqlite/adk/store.py @@ -350,20 +350,17 @@ async def append_event(self, event_record: EventRecord) -> None: await conn.commit() async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + session_id: str, + state: "dict[str, Any]", + *, + app_name: "str | None" = None, + user_id: "str | None" = None, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: - """Atomically append an event and update the session's durable state. - - Inserts the event and updates the session state + update_time in a - single transaction. Both operations succeed or fail together. Returns - the updated SessionRecord via SQLite RETURNING (3.35+). - - Args: - event_record: Event record to store. - session_id: Session identifier whose state should be updated. - state: Post-append durable state snapshot (temp: keys already - stripped by the service layer). - """ + """Atomically append an event and update session + scoped state.""" import uuid timestamp_julian = _datetime_to_julian(event_record["timestamp"]) @@ -385,6 +382,22 @@ async def append_event_and_update_state( RETURNING id, app_name, user_id, state, create_time, update_time """ + app_upsert_sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES (?, ?, ?) + ON CONFLICT(app_name) DO UPDATE SET + state = excluded.state, + update_time = excluded.update_time + """ + + user_upsert_sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (?, ?, ?, ?) + ON CONFLICT(app_name, user_id) DO UPDATE SET + state = excluded.state, + update_time = excluded.update_time + """ + async with self._config.provide_connection() as conn: await self._apply_pragmas(conn) await conn.execute( @@ -400,6 +413,16 @@ async def append_event_and_update_state( ) cursor = await conn.execute(update_sql, (state_json, now_julian, session_id)) row = await cursor.fetchone() + if app_state: + if app_name is None: + msg = "app_name is required when app_state is provided." + raise ValueError(msg) + await conn.execute(app_upsert_sql, (app_name, to_json(app_state), now_julian)) + if user_state: + if app_name is None or user_id is None: + msg = "app_name and user_id are required when user_state is provided." + raise ValueError(msg) + await conn.execute(user_upsert_sql, (app_name, user_id, to_json(user_state), now_julian)) await conn.commit() if row is None: diff --git a/sqlspec/adapters/asyncmy/adk/store.py b/sqlspec/adapters/asyncmy/adk/store.py index 69709ae63..48e263819 100644 --- a/sqlspec/adapters/asyncmy/adk/store.py +++ b/sqlspec/adapters/asyncmy/adk/store.py @@ -272,18 +272,21 @@ async def append_event(self, event_record: EventRecord) -> None: await conn.commit() async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + session_id: str, + state: "dict[str, Any]", + *, + app_name: "str | None" = None, + user_id: "str | None" = None, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: - """Atomically append an event and update the session's durable state. + """Atomically append an event and update session + scoped state. MySQL doesn't support UPDATE...RETURNING; we follow the UPDATE with a SELECT inside the same transaction so callers get the refreshed row in a single round-trip pair (no separate connection acquisition). - - Args: - event_record: Event record to store. - session_id: Session identifier whose state should be updated. - state: Post-append durable state snapshot. """ event_data = event_record["event_data"] event_data_str = to_json(event_data) if not isinstance(event_data, str) else event_data @@ -307,6 +310,18 @@ async def append_event_and_update_state( WHERE id = %s """ + app_upsert_sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES (%s, %s, UTC_TIMESTAMP(6)) + ON DUPLICATE KEY UPDATE state = VALUES(state), update_time = UTC_TIMESTAMP(6) + """ + + user_upsert_sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (%s, %s, %s, UTC_TIMESTAMP(6)) + ON DUPLICATE KEY UPDATE state = VALUES(state), update_time = UTC_TIMESTAMP(6) + """ + async with self._config.provide_connection() as conn, conn.cursor() as cursor: await cursor.execute( insert_sql, @@ -321,6 +336,16 @@ async def append_event_and_update_state( await cursor.execute(update_sql, (state_json, session_id)) await cursor.execute(select_sql, (session_id,)) row = await cursor.fetchone() + if app_state: + if app_name is None: + msg = "app_name is required when app_state is provided." + raise ValueError(msg) + await cursor.execute(app_upsert_sql, (app_name, to_json(app_state))) + if user_state: + if app_name is None or user_id is None: + msg = "app_name and user_id are required when user_state is provided." + raise ValueError(msg) + await cursor.execute(user_upsert_sql, (app_name, user_id, to_json(user_state))) await conn.commit() if row is None: diff --git a/sqlspec/adapters/asyncpg/adk/store.py b/sqlspec/adapters/asyncpg/adk/store.py index d26e174ee..72c62382d 100644 --- a/sqlspec/adapters/asyncpg/adk/store.py +++ b/sqlspec/adapters/asyncpg/adk/store.py @@ -180,7 +180,15 @@ async def append_event(self, event_record: EventRecord) -> None: ) async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + session_id: str, + state: "dict[str, Any]", + *, + app_name: "str | None" = None, + user_id: "str | None" = None, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: insert_sql = f""" INSERT INTO {self._events_table} ( @@ -193,6 +201,20 @@ async def append_event_and_update_state( WHERE id = $2 RETURNING id, app_name, user_id, state, create_time, update_time """ + app_upsert_sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES ($1, $2, CURRENT_TIMESTAMP) + ON CONFLICT (app_name) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """ + user_upsert_sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES ($1, $2, $3, CURRENT_TIMESTAMP) + ON CONFLICT (app_name, user_id) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """ async with self._config.provide_connection() as conn, conn.transaction(): await conn.execute( @@ -204,6 +226,16 @@ async def append_event_and_update_state( event_record["event_data"], ) row = await conn.fetchrow(update_sql, state, session_id) + if app_state: + if app_name is None: + msg = "app_name is required when app_state is provided." + raise ValueError(msg) + await conn.execute(app_upsert_sql, app_name, app_state) + if user_state: + if app_name is None or user_id is None: + msg = "app_name and user_id are required when user_state is provided." + raise ValueError(msg) + await conn.execute(user_upsert_sql, app_name, user_id, user_state) if row is None: msg = f"Session {session_id} not found during append_event_and_update_state." diff --git a/sqlspec/adapters/cockroach_asyncpg/adk/store.py b/sqlspec/adapters/cockroach_asyncpg/adk/store.py index 073497182..5d6d1c2bd 100644 --- a/sqlspec/adapters/cockroach_asyncpg/adk/store.py +++ b/sqlspec/adapters/cockroach_asyncpg/adk/store.py @@ -179,7 +179,15 @@ async def append_event(self, event_record: EventRecord) -> None: ) async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + session_id: str, + state: "dict[str, Any]", + *, + app_name: "str | None" = None, + user_id: "str | None" = None, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: insert_sql = f""" INSERT INTO {self._events_table} ( @@ -192,6 +200,20 @@ async def append_event_and_update_state( WHERE id = $2 RETURNING id, app_name, user_id, state, create_time, update_time """ + app_upsert_sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES ($1, $2, CURRENT_TIMESTAMP) + ON CONFLICT (app_name) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """ + user_upsert_sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES ($1, $2, $3, CURRENT_TIMESTAMP) + ON CONFLICT (app_name, user_id) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """ async with self._config.provide_connection() as conn, conn.transaction(): await conn.execute( @@ -203,6 +225,16 @@ async def append_event_and_update_state( event_record["event_data"], ) row = await conn.fetchrow(update_sql, state, session_id) + if app_state: + if app_name is None: + msg = "app_name is required when app_state is provided." + raise ValueError(msg) + await conn.execute(app_upsert_sql, app_name, app_state) + if user_state: + if app_name is None or user_id is None: + msg = "app_name and user_id are required when user_state is provided." + raise ValueError(msg) + await conn.execute(user_upsert_sql, app_name, user_id, user_state) if row is None: msg = f"Session {session_id} not found during append_event_and_update_state." diff --git a/sqlspec/adapters/cockroach_psycopg/adk/store.py b/sqlspec/adapters/cockroach_psycopg/adk/store.py index 97436c80d..9980611a7 100644 --- a/sqlspec/adapters/cockroach_psycopg/adk/store.py +++ b/sqlspec/adapters/cockroach_psycopg/adk/store.py @@ -234,7 +234,15 @@ async def append_event(self, event_record: EventRecord) -> None: await conn.commit() async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + session_id: str, + state: "dict[str, Any]", + *, + app_name: "str | None" = None, + user_id: "str | None" = None, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: insert_sql = f""" INSERT INTO {self._events_table} ( @@ -247,6 +255,20 @@ async def append_event_and_update_state( WHERE id = %s RETURNING id, app_name, user_id, state, create_time, update_time """ + app_upsert_sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES (%s, %s, CURRENT_TIMESTAMP) + ON CONFLICT (app_name) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """ + user_upsert_sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (%s, %s, %s, CURRENT_TIMESTAMP) + ON CONFLICT (app_name, user_id) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """ event_data_value = event_record["event_data"] jsonb_value = Jsonb(event_data_value) if isinstance(event_data_value, dict) else event_data_value @@ -264,6 +286,16 @@ async def append_event_and_update_state( ) await cur.execute(update_sql.encode(), (Jsonb(state), session_id)) row = await cur.fetchone() + if app_state: + if app_name is None: + msg = "app_name is required when app_state is provided." + raise ValueError(msg) + await cur.execute(app_upsert_sql.encode(), (app_name, Jsonb(app_state))) + if user_state: + if app_name is None or user_id is None: + msg = "app_name and user_id are required when user_state is provided." + raise ValueError(msg) + await cur.execute(user_upsert_sql.encode(), (app_name, user_id, Jsonb(user_state))) await conn.commit() if row is None: @@ -550,10 +582,26 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis return await async_(self._list_sessions)(app_name, user_id) async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + session_id: str, + state: "dict[str, Any]", + *, + app_name: "str | None" = None, + user_id: "str | None" = None, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: - """Atomically append an event and update the session's durable state.""" - return await async_(self._append_event_and_update_state)(event_record, session_id, state) + """Atomically append an event and update session + scoped state.""" + return await async_(self._append_event_and_update_state)( + event_record, + session_id, + state, + app_name=app_name, + user_id=user_id, + app_state=app_state, + user_state=user_state, + ) async def get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None @@ -822,7 +870,15 @@ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Ses return [] def _append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + session_id: str, + state: "dict[str, Any]", + *, + app_name: "str | None" = None, + user_id: "str | None" = None, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: insert_sql = f""" INSERT INTO {self._events_table} ( @@ -835,6 +891,20 @@ def _append_event_and_update_state( WHERE id = %s RETURNING id, app_name, user_id, state, create_time, update_time """ + app_upsert_sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES (%s, %s, CURRENT_TIMESTAMP) + ON CONFLICT (app_name) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """ + user_upsert_sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (%s, %s, %s, CURRENT_TIMESTAMP) + ON CONFLICT (app_name, user_id) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """ event_data_value = event_record["event_data"] jsonb_value = Jsonb(event_data_value) if isinstance(event_data_value, dict) else event_data_value @@ -852,6 +922,16 @@ def _append_event_and_update_state( ) cur.execute(update_sql.encode(), (Jsonb(state), session_id)) row = cur.fetchone() + if app_state: + if app_name is None: + msg = "app_name is required when app_state is provided." + raise ValueError(msg) + cur.execute(app_upsert_sql.encode(), (app_name, Jsonb(app_state))) + if user_state: + if app_name is None or user_id is None: + msg = "app_name and user_id are required when user_state is provided." + raise ValueError(msg) + cur.execute(user_upsert_sql.encode(), (app_name, user_id, Jsonb(user_state))) conn.commit() if row is None: diff --git a/sqlspec/adapters/duckdb/adk/store.py b/sqlspec/adapters/duckdb/adk/store.py index 37e164c3d..2fd2f1d3e 100644 --- a/sqlspec/adapters/duckdb/adk/store.py +++ b/sqlspec/adapters/duckdb/adk/store.py @@ -184,21 +184,26 @@ async def append_event(self, event_record: EventRecord) -> None: await async_(self._append_event)(event_record) async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + session_id: str, + state: "dict[str, Any]", + *, + app_name: "str | None" = None, + user_id: "str | None" = None, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: - """Atomically append an event and update the session's durable state. - - The event insert and state update succeed together or fail together - within a single DuckDB transaction; the updated SessionRecord is - returned via UPDATE...RETURNING. - - Args: - event_record: Event record to store (5-key shape). - session_id: Session identifier whose state should be updated. - state: Post-append durable state snapshot (``temp:`` keys already - stripped by the service layer). - """ - return await async_(self._append_event_and_update_state)(event_record, session_id, state) + """Atomically append an event and update session + scoped state.""" + return await async_(self._append_event_and_update_state)( + event_record, + session_id, + state, + app_name=app_name, + user_id=user_id, + app_state=app_state, + user_state=user_state, + ) async def get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None @@ -589,7 +594,15 @@ def _append_event(self, event_record: EventRecord) -> None: conn.commit() def _append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + session_id: str, + state: "dict[str, Any]", + *, + app_name: "str | None" = None, + user_id: "str | None" = None, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: """Synchronous implementation of append_event_and_update_state.""" now = datetime.now(timezone.utc) @@ -609,6 +622,22 @@ def _append_event_and_update_state( RETURNING id, app_name, user_id, state, create_time, update_time """ + app_upsert_sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES (?, ?, ?) + ON CONFLICT(app_name) DO UPDATE SET + state = excluded.state, + update_time = excluded.update_time + """ + + user_upsert_sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (?, ?, ?, ?) + ON CONFLICT(app_name, user_id) DO UPDATE SET + state = excluded.state, + update_time = excluded.update_time + """ + with self._config.provide_connection() as conn: cursor = conn.execute(update_sql, (state_json, now, session_id)) row = cursor.fetchone() @@ -623,17 +652,27 @@ def _append_event_and_update_state( event_data_str, ), ) + if app_state: + if app_name is None: + msg = "app_name is required when app_state is provided." + raise ValueError(msg) + conn.execute(app_upsert_sql, (app_name, to_json(app_state), now)) + if user_state: + if app_name is None or user_id is None: + msg = "app_name and user_id are required when user_state is provided." + raise ValueError(msg) + conn.execute(user_upsert_sql, (app_name, user_id, to_json(user_state), now)) conn.commit() if row is None: msg = f"Session {session_id} not found during append_event_and_update_state." raise ValueError(msg) - session_id_val, app_name, user_id, state_data, create_time, update_time = row + session_id_val, row_app_name, row_user_id, state_data, create_time, update_time = row return SessionRecord( id=session_id_val, - app_name=app_name, - user_id=user_id, + app_name=row_app_name, + user_id=row_user_id, state=from_json(state_data) if isinstance(state_data, str) else state_data, create_time=create_time, update_time=update_time, diff --git a/sqlspec/adapters/mysqlconnector/adk/store.py b/sqlspec/adapters/mysqlconnector/adk/store.py index 374a4ef17..9b42ce861 100644 --- a/sqlspec/adapters/mysqlconnector/adk/store.py +++ b/sqlspec/adapters/mysqlconnector/adk/store.py @@ -340,18 +340,21 @@ async def append_event(self, event_record: EventRecord) -> None: await conn.commit() async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + session_id: str, + state: "dict[str, Any]", + *, + app_name: "str | None" = None, + user_id: "str | None" = None, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: - """Atomically append an event and update the session's durable state. + """Atomically append an event and update session + scoped state. MySQL doesn't support UPDATE...RETURNING; the UPDATE is followed by a SELECT inside the same transaction so callers get the refreshed row without acquiring a second connection. - - Args: - event_record: Event record to store. - session_id: Session identifier whose state should be updated. - state: Post-append durable state snapshot. """ event_data = event_record["event_data"] event_data_str = to_json(event_data) if not isinstance(event_data, str) else event_data @@ -375,6 +378,18 @@ async def append_event_and_update_state( WHERE id = %s """ + app_upsert_sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES (%s, %s, UTC_TIMESTAMP(6)) + ON DUPLICATE KEY UPDATE state = VALUES(state), update_time = UTC_TIMESTAMP(6) + """ + + user_upsert_sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (%s, %s, %s, UTC_TIMESTAMP(6)) + ON DUPLICATE KEY UPDATE state = VALUES(state), update_time = UTC_TIMESTAMP(6) + """ + async with self._config.provide_connection() as conn: cursor = await conn.cursor() try: @@ -391,6 +406,16 @@ async def append_event_and_update_state( await cursor.execute(update_sql, (state_json, session_id)) await cursor.execute(select_sql, (session_id,)) row = await cursor.fetchone() + if app_state: + if app_name is None: + msg = "app_name is required when app_state is provided." + raise ValueError(msg) + await cursor.execute(app_upsert_sql, (app_name, to_json(app_state))) + if user_state: + if app_name is None or user_id is None: + msg = "app_name and user_id are required when user_state is provided." + raise ValueError(msg) + await cursor.execute(user_upsert_sql, (app_name, user_id, to_json(user_state))) finally: await cursor.close() await conn.commit() @@ -682,10 +707,26 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis return await async_(self._list_sessions)(app_name, user_id) async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + session_id: str, + state: "dict[str, Any]", + *, + app_name: "str | None" = None, + user_id: "str | None" = None, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: - """Atomically append an event and update the session's durable state.""" - return await async_(self._append_event_and_update_state)(event_record, session_id, state) + """Atomically append an event and update session + scoped state.""" + return await async_(self._append_event_and_update_state)( + event_record, + session_id, + state, + app_name=app_name, + user_id=user_id, + app_state=app_state, + user_state=user_state, + ) async def get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None @@ -921,18 +962,21 @@ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Ses raise def _append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + session_id: str, + state: "dict[str, Any]", + *, + app_name: "str | None" = None, + user_id: "str | None" = None, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: - """Atomically create an event and update the session's durable state. + """Atomically create an event and update session + scoped state. MySQL doesn't support UPDATE...RETURNING; the UPDATE is followed by a SELECT inside the same transaction so callers get the refreshed row without acquiring a second connection. - - Args: - event_record: Event record to store. - session_id: Session identifier whose state should be updated. - state: Post-append durable state snapshot. """ event_data = event_record["event_data"] event_data_str = to_json(event_data) if not isinstance(event_data, str) else event_data @@ -956,6 +1000,18 @@ def _append_event_and_update_state( WHERE id = %s """ + app_upsert_sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES (%s, %s, UTC_TIMESTAMP(6)) + ON DUPLICATE KEY UPDATE state = VALUES(state), update_time = UTC_TIMESTAMP(6) + """ + + user_upsert_sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (%s, %s, %s, UTC_TIMESTAMP(6)) + ON DUPLICATE KEY UPDATE state = VALUES(state), update_time = UTC_TIMESTAMP(6) + """ + with self._config.provide_connection() as conn: cursor = conn.cursor() try: @@ -972,6 +1028,16 @@ def _append_event_and_update_state( cursor.execute(update_sql, (state_json, session_id)) cursor.execute(select_sql, (session_id,)) row = cursor.fetchone() + if app_state: + if app_name is None: + msg = "app_name is required when app_state is provided." + raise ValueError(msg) + cursor.execute(app_upsert_sql, (app_name, to_json(app_state))) + if user_state: + if app_name is None or user_id is None: + msg = "app_name and user_id are required when user_state is provided." + raise ValueError(msg) + cursor.execute(user_upsert_sql, (app_name, user_id, to_json(user_state))) finally: cursor.close() conn.commit() diff --git a/sqlspec/adapters/oracledb/adk/store.py b/sqlspec/adapters/oracledb/adk/store.py index 1ae0a15e3..9a20d733c 100644 --- a/sqlspec/adapters/oracledb/adk/store.py +++ b/sqlspec/adapters/oracledb/adk/store.py @@ -379,22 +379,22 @@ async def append_event(self, event_record: EventRecord) -> None: await conn.commit() async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + session_id: str, + state: "dict[str, Any]", + *, + app_name: "str | None" = None, + user_id: "str | None" = None, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: - """Atomically append an event and update the session's durable state. + """Atomically append an event and update session + scoped state. - Both the event insert and session state update are executed within a - single transaction so they succeed or fail together. The refreshed - SessionRecord is read inside the same transaction (Oracle's RETURNING - INTO requires output bind variables which complicate async cursor - handling, so a SELECT-after-UPDATE is used instead). - - Args: - event_record: Event record with 5 keys: session_id, invocation_id, - author, timestamp, event_data. - session_id: Session identifier whose state should be updated. - state: Post-append durable state snapshot (``temp:`` keys already - stripped by the service layer). + All writes are executed within a single transaction so they succeed or + fail together. The refreshed SessionRecord is read inside the same + transaction (Oracle's RETURNING INTO requires output bind variables + which complicate async cursor handling, so SELECT-after-UPDATE is used). """ insert_sql = f""" INSERT INTO {self._events_table} ( @@ -417,6 +417,28 @@ async def append_event_and_update_state( WHERE id = :id """ + app_upsert_sql = f""" + MERGE INTO {self._app_state_table} target + USING (SELECT :app_name AS app_name, :state AS state FROM DUAL) source + ON (target.app_name = source.app_name) + WHEN MATCHED THEN + UPDATE SET target.state = source.state, target.update_time = SYSTIMESTAMP + WHEN NOT MATCHED THEN + INSERT (app_name, state, update_time) + VALUES (source.app_name, source.state, SYSTIMESTAMP) + """ + + user_upsert_sql = f""" + MERGE INTO {self._user_state_table} target + USING (SELECT :app_name AS app_name, :user_id AS user_id, :state AS state FROM DUAL) source + ON (target.app_name = source.app_name AND target.user_id = source.user_id) + WHEN MATCHED THEN + UPDATE SET target.state = source.state, target.update_time = SYSTIMESTAMP + WHEN NOT MATCHED THEN + INSERT (app_name, user_id, state, update_time) + VALUES (source.app_name, source.user_id, source.state, SYSTIMESTAMP) + """ + async with self._config.provide_connection() as conn: cursor = conn.cursor() await cursor.execute( @@ -432,17 +454,32 @@ async def append_event_and_update_state( await cursor.execute(update_sql, {"state": state_data, "id": session_id}) await cursor.execute(select_sql, {"id": session_id}) row = await cursor.fetchone() + if app_state: + if app_name is None: + msg = "app_name is required when app_state is provided." + raise ValueError(msg) + await cursor.execute( + app_upsert_sql, {"app_name": app_name, "state": await self._serialize_state(app_state)} + ) + if user_state: + if app_name is None or user_id is None: + msg = "app_name and user_id are required when user_state is provided." + raise ValueError(msg) + await cursor.execute( + user_upsert_sql, + {"app_name": app_name, "user_id": user_id, "state": await self._serialize_state(user_state)}, + ) await conn.commit() if row is None: msg = f"Session {session_id} not found during append_event_and_update_state." raise ValueError(msg) - session_id_val, app_name, user_id, state_data_row, create_time, update_time = row + session_id_val, row_app_name, row_user_id, state_data_row, create_time, update_time = row return SessionRecord( id=session_id_val, - app_name=app_name, - user_id=user_id, + app_name=row_app_name, + user_id=row_user_id, state=await self._deserialize_state(state_data_row), create_time=create_time, update_time=update_time, @@ -1147,10 +1184,26 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis return await async_(self._list_sessions)(app_name, user_id) async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + session_id: str, + state: "dict[str, Any]", + *, + app_name: "str | None" = None, + user_id: "str | None" = None, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: - """Atomically append an event and update the session's durable state.""" - return await async_(self._append_event_and_update_state)(event_record, session_id, state) + """Atomically append an event and update session + scoped state.""" + return await async_(self._append_event_and_update_state)( + event_record, + session_id, + state, + app_name=app_name, + user_id=user_id, + app_state=app_state, + user_state=user_state, + ) async def get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None @@ -1848,20 +1901,21 @@ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Ses raise def _append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + session_id: str, + state: "dict[str, Any]", + *, + app_name: "str | None" = None, + user_id: "str | None" = None, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: - """Atomically create an event and update the session's durable state. - - Both the event insert and session state update are executed within a - single transaction so they succeed or fail together; the refreshed - SessionRecord is read inside the same transaction. + """Atomically create an event and update session + scoped state. - Args: - event_record: Event record with 5 keys: session_id, invocation_id, - author, timestamp, event_data. - session_id: Session identifier whose state should be updated. - state: Post-append durable state snapshot (``temp:`` keys already - stripped by the service layer). + All writes are executed within a single transaction so they succeed or + fail together; the refreshed SessionRecord is read inside the same + transaction. """ insert_sql = f""" INSERT INTO {self._events_table} ( @@ -1884,6 +1938,28 @@ def _append_event_and_update_state( WHERE id = :id """ + app_upsert_sql = f""" + MERGE INTO {self._app_state_table} target + USING (SELECT :app_name AS app_name, :state AS state FROM DUAL) source + ON (target.app_name = source.app_name) + WHEN MATCHED THEN + UPDATE SET target.state = source.state, target.update_time = SYSTIMESTAMP + WHEN NOT MATCHED THEN + INSERT (app_name, state, update_time) + VALUES (source.app_name, source.state, SYSTIMESTAMP) + """ + + user_upsert_sql = f""" + MERGE INTO {self._user_state_table} target + USING (SELECT :app_name AS app_name, :user_id AS user_id, :state AS state FROM DUAL) source + ON (target.app_name = source.app_name AND target.user_id = source.user_id) + WHEN MATCHED THEN + UPDATE SET target.state = source.state, target.update_time = SYSTIMESTAMP + WHEN NOT MATCHED THEN + INSERT (app_name, user_id, state, update_time) + VALUES (source.app_name, source.user_id, source.state, SYSTIMESTAMP) + """ + with self._config.provide_connection() as conn: cursor = conn.cursor() cursor.execute( @@ -1899,17 +1975,30 @@ def _append_event_and_update_state( cursor.execute(update_sql, {"state": state_data, "id": session_id}) cursor.execute(select_sql, {"id": session_id}) row = cursor.fetchone() + if app_state: + if app_name is None: + msg = "app_name is required when app_state is provided." + raise ValueError(msg) + cursor.execute(app_upsert_sql, {"app_name": app_name, "state": self._serialize_state(app_state)}) + if user_state: + if app_name is None or user_id is None: + msg = "app_name and user_id are required when user_state is provided." + raise ValueError(msg) + cursor.execute( + user_upsert_sql, + {"app_name": app_name, "user_id": user_id, "state": self._serialize_state(user_state)}, + ) conn.commit() if row is None: msg = f"Session {session_id} not found during append_event_and_update_state." raise ValueError(msg) - session_id_val, app_name, user_id, state_data_row, create_time, update_time = row + session_id_val, row_app_name, row_user_id, state_data_row, create_time, update_time = row return SessionRecord( id=session_id_val, - app_name=app_name, - user_id=user_id, + app_name=row_app_name, + user_id=row_user_id, state=self._deserialize_state(state_data_row), create_time=create_time, update_time=update_time, diff --git a/sqlspec/adapters/psqlpy/adk/store.py b/sqlspec/adapters/psqlpy/adk/store.py index e33b3dd86..25dac6d71 100644 --- a/sqlspec/adapters/psqlpy/adk/store.py +++ b/sqlspec/adapters/psqlpy/adk/store.py @@ -196,7 +196,15 @@ async def append_event(self, event_record: EventRecord) -> None: ) async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + session_id: str, + state: "dict[str, Any]", + *, + app_name: "str | None" = None, + user_id: "str | None" = None, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: insert_sql = f""" INSERT INTO {self._events_table} ( @@ -209,6 +217,20 @@ async def append_event_and_update_state( WHERE id = $2 RETURNING id, app_name, user_id, state, create_time, update_time """ + app_upsert_sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES ($1, $2, CURRENT_TIMESTAMP) + ON CONFLICT (app_name) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """ + user_upsert_sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES ($1, $2, $3, CURRENT_TIMESTAMP) + ON CONFLICT (app_name, user_id) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """ async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] await conn.execute( @@ -223,6 +245,16 @@ async def append_event_and_update_state( ) result = await conn.fetch(update_sql, [state, session_id]) rows: list[dict[str, Any]] = result.result() if result else [] + if app_state: + if app_name is None: + msg = "app_name is required when app_state is provided." + raise ValueError(msg) + await conn.execute(app_upsert_sql, [app_name, app_state]) + if user_state: + if app_name is None or user_id is None: + msg = "app_name and user_id are required when user_state is provided." + raise ValueError(msg) + await conn.execute(user_upsert_sql, [app_name, user_id, user_state]) if not rows: msg = f"Session {session_id} not found during append_event_and_update_state." diff --git a/sqlspec/adapters/psycopg/adk/store.py b/sqlspec/adapters/psycopg/adk/store.py index ee6c414a1..9d09f6168 100644 --- a/sqlspec/adapters/psycopg/adk/store.py +++ b/sqlspec/adapters/psycopg/adk/store.py @@ -227,7 +227,15 @@ async def append_event(self, event_record: EventRecord) -> None: ) async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + session_id: str, + state: "dict[str, Any]", + *, + app_name: "str | None" = None, + user_id: "str | None" = None, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: insert_query = pg_sql.SQL(""" INSERT INTO {table} ( @@ -242,6 +250,22 @@ async def append_event_and_update_state( RETURNING id, app_name, user_id, state, create_time, update_time """).format(table=pg_sql.Identifier(self._session_table)) + app_upsert_query = pg_sql.SQL(""" + INSERT INTO {table} (app_name, state, update_time) + VALUES (%s, %s, CURRENT_TIMESTAMP) + ON CONFLICT (app_name) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """).format(table=pg_sql.Identifier(self._app_state_table)) + + user_upsert_query = pg_sql.SQL(""" + INSERT INTO {table} (app_name, user_id, state, update_time) + VALUES (%s, %s, %s, CURRENT_TIMESTAMP) + ON CONFLICT (app_name, user_id) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """).format(table=pg_sql.Identifier(self._user_state_table)) + event_data_value = event_record["event_data"] jsonb_value = Jsonb(event_data_value) if isinstance(event_data_value, dict) else event_data_value @@ -258,6 +282,16 @@ async def append_event_and_update_state( ) await cur.execute(update_query, (Jsonb(state), session_id)) row = await cur.fetchone() + if app_state: + if app_name is None: + msg = "app_name is required when app_state is provided." + raise ValueError(msg) + await cur.execute(app_upsert_query, (app_name, Jsonb(app_state))) + if user_state: + if app_name is None or user_id is None: + msg = "app_name and user_id are required when user_state is provided." + raise ValueError(msg) + await cur.execute(user_upsert_query, (app_name, user_id, Jsonb(user_state))) await conn.commit() if row is None: @@ -567,10 +601,26 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis return await async_(self._list_sessions)(app_name, user_id) async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + session_id: str, + state: "dict[str, Any]", + *, + app_name: "str | None" = None, + user_id: "str | None" = None, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: - """Atomically append an event and update the session's durable state.""" - return await async_(self._append_event_and_update_state)(event_record, session_id, state) + """Atomically append an event and update session + scoped state.""" + return await async_(self._append_event_and_update_state)( + event_record, + session_id, + state, + app_name=app_name, + user_id=user_id, + app_state=app_state, + user_state=user_state, + ) async def get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None @@ -856,7 +906,15 @@ def _insert_event(self, event_record: EventRecord) -> None: conn.commit() def _append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + session_id: str, + state: "dict[str, Any]", + *, + app_name: "str | None" = None, + user_id: "str | None" = None, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: insert_query = pg_sql.SQL(""" INSERT INTO {table} ( @@ -871,6 +929,22 @@ def _append_event_and_update_state( RETURNING id, app_name, user_id, state, create_time, update_time """).format(table=pg_sql.Identifier(self._session_table)) + app_upsert_query = pg_sql.SQL(""" + INSERT INTO {table} (app_name, state, update_time) + VALUES (%s, %s, CURRENT_TIMESTAMP) + ON CONFLICT (app_name) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """).format(table=pg_sql.Identifier(self._app_state_table)) + + user_upsert_query = pg_sql.SQL(""" + INSERT INTO {table} (app_name, user_id, state, update_time) + VALUES (%s, %s, %s, CURRENT_TIMESTAMP) + ON CONFLICT (app_name, user_id) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """).format(table=pg_sql.Identifier(self._user_state_table)) + event_data_value = event_record["event_data"] jsonb_value = Jsonb(event_data_value) if isinstance(event_data_value, dict) else event_data_value @@ -887,6 +961,16 @@ def _append_event_and_update_state( ) cur.execute(update_query, (Jsonb(state), session_id)) row = cur.fetchone() + if app_state: + if app_name is None: + msg = "app_name is required when app_state is provided." + raise ValueError(msg) + cur.execute(app_upsert_query, (app_name, Jsonb(app_state))) + if user_state: + if app_name is None or user_id is None: + msg = "app_name and user_id are required when user_state is provided." + raise ValueError(msg) + cur.execute(user_upsert_query, (app_name, user_id, Jsonb(user_state))) conn.commit() if row is None: diff --git a/sqlspec/adapters/pymysql/adk/store.py b/sqlspec/adapters/pymysql/adk/store.py index c130aa59e..354d159f6 100644 --- a/sqlspec/adapters/pymysql/adk/store.py +++ b/sqlspec/adapters/pymysql/adk/store.py @@ -86,10 +86,26 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis return await async_(self._list_sessions)(app_name, user_id) async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + session_id: str, + state: "dict[str, Any]", + *, + app_name: "str | None" = None, + user_id: "str | None" = None, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: - """Atomically append an event and update the session's durable state.""" - return await async_(self._append_event_and_update_state)(event_record, session_id, state) + """Atomically append an event and update session + scoped state.""" + return await async_(self._append_event_and_update_state)( + event_record, + session_id, + state, + app_name=app_name, + user_id=user_id, + app_state=app_state, + user_state=user_state, + ) async def get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None @@ -389,18 +405,21 @@ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Ses raise def _append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + session_id: str, + state: "dict[str, Any]", + *, + app_name: "str | None" = None, + user_id: "str | None" = None, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: - """Atomically create an event and update the session's durable state. + """Atomically create an event and update session + scoped state. MySQL doesn't support UPDATE...RETURNING; the UPDATE is followed by a SELECT inside the same transaction so callers get the refreshed row without acquiring a second connection. - - Args: - event_record: Event record to store. - session_id: Session identifier whose state should be updated. - state: Post-append durable state snapshot. """ event_data = event_record["event_data"] event_data_str = to_json(event_data) if not isinstance(event_data, str) else event_data @@ -424,6 +443,18 @@ def _append_event_and_update_state( WHERE id = %s """ + app_upsert_sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES (%s, %s, UTC_TIMESTAMP(6)) + ON DUPLICATE KEY UPDATE state = VALUES(state), update_time = UTC_TIMESTAMP(6) + """ + + user_upsert_sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (%s, %s, %s, UTC_TIMESTAMP(6)) + ON DUPLICATE KEY UPDATE state = VALUES(state), update_time = UTC_TIMESTAMP(6) + """ + with self._config.provide_connection() as conn: cursor = conn.cursor() try: @@ -440,6 +471,16 @@ def _append_event_and_update_state( cursor.execute(update_sql, (state_json, session_id)) cursor.execute(select_sql, (session_id,)) row = cursor.fetchone() + if app_state: + if app_name is None: + msg = "app_name is required when app_state is provided." + raise ValueError(msg) + cursor.execute(app_upsert_sql, (app_name, to_json(app_state))) + if user_state: + if app_name is None or user_id is None: + msg = "app_name and user_id are required when user_state is provided." + raise ValueError(msg) + cursor.execute(user_upsert_sql, (app_name, user_id, to_json(user_state))) finally: cursor.close() conn.commit() diff --git a/sqlspec/adapters/spanner/adk/store.py b/sqlspec/adapters/spanner/adk/store.py index ea6dad191..c74c0538d 100644 --- a/sqlspec/adapters/spanner/adk/store.py +++ b/sqlspec/adapters/spanner/adk/store.py @@ -66,10 +66,26 @@ async def delete_session(self, session_id: str) -> None: await async_(self._delete_session)(session_id) async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + session_id: str, + state: "dict[str, Any]", + *, + app_name: "str | None" = None, + user_id: "str | None" = None, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: - """Atomically append an event and update the session's durable state.""" - return await async_(self._append_event_and_update_state)(event_record, session_id, state) + """Atomically append an event and update session + scoped state.""" + return await async_(self._append_event_and_update_state)( + event_record, + session_id, + state, + app_name=app_name, + user_id=user_id, + app_state=app_state, + user_state=user_state, + ) async def get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None @@ -283,20 +299,22 @@ def _delete_session(self, session_id: str) -> None: self._run_write([(delete_events_sql, params, types), (delete_session_sql, params, types)]) def _append_event_and_update_state( - self, event_record: "EventRecord", session_id: str, state: "dict[str, Any]" + self, + event_record: "EventRecord", + session_id: str, + state: "dict[str, Any]", + *, + app_name: "str | None" = None, + user_id: "str | None" = None, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: - """Atomically insert an event and update session state in one transaction. - - Both the event INSERT and the session state UPDATE execute within a single - Spanner transaction so they succeed or fail together. A follow-up - single-use read returns the SessionRecord; we can't capture update_time - inside the write txn because PENDING_COMMIT_TIMESTAMP() only materialises - on commit. - - Args: - event_record: Event record to store. - session_id: Session whose state should be updated. - state: Post-append durable state snapshot. + """Atomically insert event + update session + upsert scoped state. + + All writes execute within a single Spanner transaction so they succeed + or fail together. A follow-up single-use read returns the SessionRecord; + we can't capture update_time inside the write txn because + PENDING_COMMIT_TIMESTAMP() only materialises on commit. """ event_params: dict[str, Any] = { "session_id": event_record["session_id"], @@ -320,10 +338,47 @@ def _append_event_and_update_state( if self._shard_count > 1: update_sql = f"{update_sql} AND shard_id = MOD(FARM_FINGERPRINT(@id), {self._shard_count})" - self._run_write([ + statements: list[tuple[str, dict[str, Any], dict[str, Any]]] = [ (insert_sql, event_params, self._event_param_types()), (update_sql, state_params, {"id": SPANNER_PARAM_TYPES.STRING, "state": json_type}), - ]) + ] + + if app_state: + if app_name is None: + msg = "app_name is required when app_state is provided." + raise ValueError(msg) + app_delete_sql = f"DELETE FROM {self._app_state_table} WHERE app_name = @app_name" + app_insert_sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES (@app_name, @state, PENDING_COMMIT_TIMESTAMP()) + """ + statements.append((app_delete_sql, {"app_name": app_name}, {"app_name": SPANNER_PARAM_TYPES.STRING})) + statements.append(( + app_insert_sql, + {"app_name": app_name, "state": to_json(app_state)}, + {"app_name": SPANNER_PARAM_TYPES.STRING, "state": json_type}, + )) + if user_state: + if app_name is None or user_id is None: + msg = "app_name and user_id are required when user_state is provided." + raise ValueError(msg) + user_delete_sql = f"DELETE FROM {self._user_state_table} WHERE app_name = @app_name AND user_id = @user_id" + user_insert_sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (@app_name, @user_id, @state, PENDING_COMMIT_TIMESTAMP()) + """ + statements.append(( + user_delete_sql, + {"app_name": app_name, "user_id": user_id}, + {"app_name": SPANNER_PARAM_TYPES.STRING, "user_id": SPANNER_PARAM_TYPES.STRING}, + )) + statements.append(( + user_insert_sql, + {"app_name": app_name, "user_id": user_id, "state": to_json(user_state)}, + {"app_name": SPANNER_PARAM_TYPES.STRING, "user_id": SPANNER_PARAM_TYPES.STRING, "state": json_type}, + )) + + self._run_write(statements) record = self._get_session(session_id) if record is None: diff --git a/sqlspec/adapters/sqlite/adk/store.py b/sqlspec/adapters/sqlite/adk/store.py index 3db5e2b72..1ee8049b2 100644 --- a/sqlspec/adapters/sqlite/adk/store.py +++ b/sqlspec/adapters/sqlite/adk/store.py @@ -217,20 +217,26 @@ async def append_event(self, event_record: EventRecord) -> None: await async_(self._append_event)(event_record) async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + session_id: str, + state: "dict[str, Any]", + *, + app_name: "str | None" = None, + user_id: "str | None" = None, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: - """Atomically append an event and update the session's durable state. - - Inserts the event and updates the session state + update_time in a - single transaction, returning the updated SessionRecord via RETURNING. - - Args: - event_record: Event record to store. - session_id: Session identifier whose state should be updated. - state: Post-append durable state snapshot (temp: keys already - stripped by the service layer). - """ - return await async_(self._append_event_and_update_state)(event_record, session_id, state) + """Atomically append an event and update session + scoped state.""" + return await async_(self._append_event_and_update_state)( + event_record, + session_id, + state, + app_name=app_name, + user_id=user_id, + app_state=app_state, + user_state=user_state, + ) async def get_events( self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None @@ -601,7 +607,15 @@ def _append_event(self, event_record: EventRecord) -> None: conn.commit() def _append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + session_id: str, + state: "dict[str, Any]", + *, + app_name: "str | None" = None, + user_id: "str | None" = None, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: """Synchronous implementation of append_event_and_update_state.""" import uuid @@ -625,6 +639,22 @@ def _append_event_and_update_state( RETURNING id, app_name, user_id, state, create_time, update_time """ + app_upsert_sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES (?, ?, ?) + ON CONFLICT(app_name) DO UPDATE SET + state = excluded.state, + update_time = excluded.update_time + """ + + user_upsert_sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (?, ?, ?, ?) + ON CONFLICT(app_name, user_id) DO UPDATE SET + state = excluded.state, + update_time = excluded.update_time + """ + with self._config.provide_connection() as conn: self._apply_pragmas(conn) conn.execute( @@ -640,6 +670,16 @@ def _append_event_and_update_state( ) cursor = conn.execute(update_sql, (state_json, now_julian, session_id)) row = cursor.fetchone() + if app_state: + if app_name is None: + msg = "app_name is required when app_state is provided." + raise ValueError(msg) + conn.execute(app_upsert_sql, (app_name, to_json(app_state), now_julian)) + if user_state: + if app_name is None or user_id is None: + msg = "app_name and user_id are required when user_state is provided." + raise ValueError(msg) + conn.execute(user_upsert_sql, (app_name, user_id, to_json(user_state), now_julian)) conn.commit() if row is None: diff --git a/sqlspec/extensions/adk/service.py b/sqlspec/extensions/adk/service.py index d0ef811c4..26c4aaaa0 100644 --- a/sqlspec/extensions/adk/service.py +++ b/sqlspec/extensions/adk/service.py @@ -279,14 +279,16 @@ async def append_event(self, session: "Session", event: "Event") -> "Event": ) raise ValueError(msg) - # --- Persist event and state atomically --- + # --- Persist event and all scoped state atomically --- updated_record = await self._store.append_event_and_update_state( - event_record=event_record, session_id=session.id, state=session_state + event_record=event_record, + session_id=session.id, + state=session_state, + app_name=session.app_name, + user_id=session.user_id, + app_state=app_state or None, + user_state=user_state or None, ) - if app_state: - await self._store.upsert_app_state(session.app_name, app_state) - if user_state: - await self._store.upsert_user_state(session.app_name, session.user_id, user_state) updated_record["state"] = merge_scoped_state(updated_record["state"], app_state, user_state) # Use the returned record directly — saves a round-trip vs a follow-up get_session(). diff --git a/sqlspec/extensions/adk/store.py b/sqlspec/extensions/adk/store.py index e0509a532..ff1fe1d8b 100644 --- a/sqlspec/extensions/adk/store.py +++ b/sqlspec/extensions/adk/store.py @@ -282,20 +282,40 @@ async def append_event(self, event_record: "EventRecord") -> None: @abstractmethod async def append_event_and_update_state( - self, event_record: "EventRecord", session_id: str, state: "dict[str, Any]" + self, + event_record: "EventRecord", + session_id: str, + state: "dict[str, Any]", + *, + app_name: "str | None" = None, + user_id: "str | None" = None, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> "SessionRecord": """Atomically append an event and update the session's durable state. This is the authoritative durable write boundary for post-creation - session mutations. The event insert and state update must succeed - together or fail together, and the updated session record is returned - in the same round-trip so callers don't need a follow-up read. + session mutations. The event insert, session state update, and the + optional scoped-state upserts must succeed together or fail together, + and the updated session record is returned in the same round-trip so + callers don't need a follow-up read. + + When ``app_state`` is provided (non-None), it is upserted into the + ``app_state_table`` for ``app_name``. When ``user_state`` is provided, + it is upserted into the ``user_state_table`` for ``(app_name, user_id)``. + Empty dicts are treated as "no scoped delta" and skipped. Args: event_record: Event record to store. session_id: Session identifier whose state should be updated. - state: Post-append durable state snapshot (``temp:`` keys already - stripped by the service layer). + state: Post-append durable session-scoped state snapshot + (``temp:`` keys already stripped by the service layer). + app_name: Application name for routing scoped-state upserts. Required + when ``app_state`` or ``user_state`` is non-empty. + user_id: User identifier for routing user-scoped upserts. Required + when ``user_state`` is non-empty. + app_state: App-scoped state delta (``app:*`` keys) to upsert atomically. + user_state: User-scoped state delta (``user:*`` keys) to upsert atomically. Returns: The updated SessionRecord reflecting the new state and update_time. diff --git a/tests/integration/adapters/_adk_contract_helpers.py b/tests/integration/adapters/_adk_contract_helpers.py index cab6cda8d..d3ab5e8a2 100644 --- a/tests/integration/adapters/_adk_contract_helpers.py +++ b/tests/integration/adapters/_adk_contract_helpers.py @@ -10,6 +10,7 @@ __all__ = ( "assert_memory_store_contract", + "assert_session_atomic_scoped_write_contract", "assert_session_event_cleanup_contract", "assert_session_event_store_contract", "assert_session_get_session_renewal_contract", @@ -38,7 +39,15 @@ async def delete_session(self, session_id: str) -> None: ... async def append_event(self, event_record: EventRecord) -> None: ... async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: dict[str, object] + self, + event_record: EventRecord, + session_id: str, + state: dict[str, object], + *, + app_name: str | None = None, + user_id: str | None = None, + app_state: dict[str, object] | None = None, + user_state: dict[str, object] | None = None, ) -> SessionRecord: ... async def get_events( @@ -306,6 +315,65 @@ async def assert_session_scoped_state_contract(store: SessionEventStore, *, mark assert fetched_other_user.state == {"app:counter": 1} +async def assert_session_atomic_scoped_write_contract(store: SessionEventStore, *, marker: str) -> None: + """Assert append_event_and_update_state accepts scoped-state kwargs. + + Verifies the store-level atomic write delivers events INSERT + sessions UPDATE + + app_state UPSERT + user_state UPSERT in a single round-trip when callers + supply ``app_state`` and ``user_state`` alongside the session state snapshot. + """ + app_name = _contract_key(marker, "atomic-app") + user_id = _contract_key(marker, "atomic-user") + session_id = _contract_key(marker, "atomic-session") + no_scope_session_id = _contract_key(marker, "atomic-no-scope-session") + base_time = datetime(2026, 5, 24, 12, 0, tzinfo=timezone.utc) + + await store.create_session(session_id, app_name, user_id, {"initial": 0}) + + event = _event_record( + session_id=session_id, + event_id="atomic-event-1", + invocation_id="atomic-inv-1", + author="user", + timestamp=base_time, + event_data={"actions": {"state_delta": {"app:counter": 1, "user:theme": "dark", "turn": 1}}}, + ) + + updated = await store.append_event_and_update_state( + event, + session_id, + {"turn": 1}, + app_name=app_name, + user_id=user_id, + app_state={"app:counter": 1}, + user_state={"user:theme": "dark"}, + ) + + assert updated["state"] == {"turn": 1} + assert updated["id"] == session_id + assert await store.get_app_state(app_name) == {"app:counter": 1} + assert await store.get_user_state(app_name, user_id) == {"user:theme": "dark"} + stored_events = await store.get_events(session_id) + assert any(record["invocation_id"] == "atomic-inv-1" for record in stored_events) + + await store.create_session(no_scope_session_id, app_name, user_id, {"phase": 0}) + no_scope_event = _event_record( + session_id=no_scope_session_id, + event_id="atomic-event-2", + invocation_id="atomic-inv-2", + author="model", + timestamp=base_time + timedelta(seconds=1), + event_data={"content": {"parts": [{"text": "no scope delta"}]}}, + ) + no_scope_update = await store.append_event_and_update_state( + no_scope_event, no_scope_session_id, {"phase": 1}, app_name=app_name, user_id=user_id + ) + assert no_scope_update["state"] == {"phase": 1} + # Skipped scoped writes leave existing app/user state untouched. + assert await store.get_app_state(app_name) == {"app:counter": 1} + assert await store.get_user_state(app_name, user_id) == {"user:theme": "dark"} + + async def assert_session_table_lifecycle_contract(store: SessionEventStore, *, marker: str) -> None: """Assert ADK stores can drop and recreate their managed session tables.""" app_name = _contract_key(marker, "lifecycle-app") diff --git a/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py b/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py index 74c3ee353..7f2daf644 100644 --- a/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py +++ b/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py @@ -8,6 +8,7 @@ from sqlspec.adapters.adbc import AdbcConfig from sqlspec.adapters.adbc.adk import AdbcADKStore from tests.integration.adapters._adk_contract_helpers import ( + assert_session_atomic_scoped_write_contract, assert_session_event_cleanup_contract, assert_session_get_session_renewal_contract, assert_session_scoped_state_contract, @@ -117,6 +118,11 @@ async def test_session_table_lifecycle_contract(adbc_store: Any) -> None: await assert_session_table_lifecycle_contract(adbc_store, marker="adbc") +async def test_session_atomic_scoped_write_contract(adbc_store: Any) -> None: + """ADBC routes scoped-state upserts inside the append/update transaction.""" + await assert_session_atomic_scoped_write_contract(adbc_store, marker="adbc") + + async def test_list_sessions(adbc_store: Any) -> None: """Test listing sessions for an app and user.""" app_name = "test-app" diff --git a/tests/integration/adapters/aiomysql/extensions/adk/test_store.py b/tests/integration/adapters/aiomysql/extensions/adk/test_store.py index a162ced06..61aee3bdf 100644 --- a/tests/integration/adapters/aiomysql/extensions/adk/test_store.py +++ b/tests/integration/adapters/aiomysql/extensions/adk/test_store.py @@ -8,7 +8,10 @@ from sqlspec.adapters.aiomysql._typing import AiomysqlCursor from sqlspec.adapters.aiomysql.adk import AiomysqlADKStore from sqlspec.extensions.adk import EventRecord -from tests.integration.adapters._adk_contract_helpers import assert_session_scoped_state_contract +from tests.integration.adapters._adk_contract_helpers import ( + assert_session_atomic_scoped_write_contract, + assert_session_scoped_state_contract, +) pytestmark = [pytest.mark.xdist_group("mysql"), pytest.mark.aiomysql, pytest.mark.integration] @@ -24,6 +27,11 @@ async def test_aiomysql_session_scoped_state_contract(aiomysql_adk_store: Aiomys await assert_session_scoped_state_contract(aiomysql_adk_store, marker="aiomysql") +async def test_aiomysql_session_atomic_scoped_write_contract(aiomysql_adk_store: AiomysqlADKStore) -> None: + """Aiomysql routes scoped-state upserts inside the append/update transaction.""" + await assert_session_atomic_scoped_write_contract(aiomysql_adk_store, marker="aiomysql") + + async def test_storage_types_verification(aiomysql_adk_store: AiomysqlADKStore) -> None: """Verify MySQL uses JSON type (not TEXT) and TIMESTAMP(6) for microseconds. diff --git a/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py b/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py index 81d1e9a41..20965e187 100644 --- a/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py +++ b/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py @@ -9,6 +9,7 @@ from sqlspec.adapters.aiosqlite.adk import AiosqliteADKStore from sqlspec.extensions.adk import EventRecord from tests.integration.adapters._adk_contract_helpers import ( + assert_session_atomic_scoped_write_contract, assert_session_event_cleanup_contract, assert_session_event_store_contract, assert_session_get_session_renewal_contract, @@ -106,6 +107,15 @@ async def test_aiosqlite_session_table_lifecycle_contract(tmp_path: Path) -> Non await config.close_pool() +async def test_aiosqlite_session_atomic_scoped_write_contract(tmp_path: Path) -> None: + """AioSQLite routes scoped-state upserts inside the append/update transaction.""" + config, store = await _build_store(tmp_path) + try: + await assert_session_atomic_scoped_write_contract(store, marker="aiosqlite") + finally: + await config.close_pool() + + async def test_aiosqlite_append_event_and_update_state_is_atomic_contract(tmp_path: Path) -> None: """Event append and durable state update happen through the clean-break method.""" config, store = await _build_store(tmp_path) diff --git a/tests/integration/adapters/asyncmy/extensions/adk/test_store.py b/tests/integration/adapters/asyncmy/extensions/adk/test_store.py index 92ce3400c..f9453d6bd 100644 --- a/tests/integration/adapters/asyncmy/extensions/adk/test_store.py +++ b/tests/integration/adapters/asyncmy/extensions/adk/test_store.py @@ -7,7 +7,10 @@ from sqlspec.adapters.asyncmy.adk import AsyncmyADKStore from sqlspec.extensions.adk import EventRecord -from tests.integration.adapters._adk_contract_helpers import assert_session_scoped_state_contract +from tests.integration.adapters._adk_contract_helpers import ( + assert_session_atomic_scoped_write_contract, + assert_session_scoped_state_contract, +) pytestmark = [pytest.mark.xdist_group("mysql"), pytest.mark.asyncmy, pytest.mark.integration] @@ -23,6 +26,11 @@ async def test_asyncmy_session_scoped_state_contract(asyncmy_adk_store: AsyncmyA await assert_session_scoped_state_contract(asyncmy_adk_store, marker="asyncmy") +async def test_asyncmy_session_atomic_scoped_write_contract(asyncmy_adk_store: AsyncmyADKStore) -> None: + """Asyncmy routes scoped-state upserts inside the append/update transaction.""" + await assert_session_atomic_scoped_write_contract(asyncmy_adk_store, marker="asyncmy") + + async def test_storage_types_verification(asyncmy_adk_store: AsyncmyADKStore) -> None: """Verify MySQL uses JSON type (not TEXT) and TIMESTAMP(6) for microseconds. diff --git a/tests/integration/adapters/asyncpg/extensions/adk/test_session_operations.py b/tests/integration/adapters/asyncpg/extensions/adk/test_session_operations.py index 435765823..83944ee5b 100644 --- a/tests/integration/adapters/asyncpg/extensions/adk/test_session_operations.py +++ b/tests/integration/adapters/asyncpg/extensions/adk/test_session_operations.py @@ -4,7 +4,10 @@ import pytest -from tests.integration.adapters._adk_contract_helpers import assert_session_scoped_state_contract +from tests.integration.adapters._adk_contract_helpers import ( + assert_session_atomic_scoped_write_contract, + assert_session_scoped_state_contract, +) pytestmark = [pytest.mark.xdist_group("postgres"), pytest.mark.asyncpg, pytest.mark.integration] @@ -47,6 +50,11 @@ async def test_asyncpg_session_scoped_state_contract(asyncpg_adk_store: Any) -> await assert_session_scoped_state_contract(asyncpg_adk_store, marker="asyncpg") +async def test_asyncpg_session_atomic_scoped_write_contract(asyncpg_adk_store: Any) -> None: + """Asyncpg routes scoped-state upserts inside the append/update transaction.""" + await assert_session_atomic_scoped_write_contract(asyncpg_adk_store, marker="asyncpg") + + async def test_get_nonexistent_session(asyncpg_adk_store: Any) -> None: """Test retrieving a session that doesn't exist.""" result = await asyncpg_adk_store.get_session("nonexistent") diff --git a/tests/integration/adapters/duckdb/extensions/adk/test_store.py b/tests/integration/adapters/duckdb/extensions/adk/test_store.py index ae2e7dee3..118f7cd98 100644 --- a/tests/integration/adapters/duckdb/extensions/adk/test_store.py +++ b/tests/integration/adapters/duckdb/extensions/adk/test_store.py @@ -11,6 +11,7 @@ from sqlspec.adapters.duckdb.config import DuckDBConfig from sqlspec.extensions.adk import EventRecord from tests.integration.adapters._adk_contract_helpers import ( + assert_session_atomic_scoped_write_contract, assert_session_event_cleanup_contract, assert_session_event_store_contract, assert_session_get_session_renewal_contract, @@ -79,6 +80,11 @@ async def test_duckdb_session_table_lifecycle_contract(duckdb_adk_store: DuckdbA await assert_session_table_lifecycle_contract(duckdb_adk_store, marker="duckdb") +async def test_duckdb_session_atomic_scoped_write_contract(duckdb_adk_store: DuckdbADKStore) -> None: + """DuckDB routes scoped-state upserts inside the append/update transaction.""" + await assert_session_atomic_scoped_write_contract(duckdb_adk_store, marker="duckdb") + + async def test_create_and_get_session(duckdb_adk_store: DuckdbADKStore) -> None: """Test creating and retrieving a session.""" session_id = "session-001" diff --git a/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py b/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py index ba288bd1f..399d24b80 100644 --- a/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py +++ b/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py @@ -8,7 +8,10 @@ from sqlspec.adapters.mysqlconnector.adk import MysqlConnectorAsyncADKStore from sqlspec.extensions.adk import EventRecord -from tests.integration.adapters._adk_contract_helpers import assert_session_scoped_state_contract +from tests.integration.adapters._adk_contract_helpers import ( + assert_session_atomic_scoped_write_contract, + assert_session_scoped_state_contract, +) pytestmark = [pytest.mark.xdist_group("mysql"), pytest.mark.mysql_connector, pytest.mark.integration] @@ -26,6 +29,13 @@ async def test_mysqlconnector_session_scoped_state_contract( await assert_session_scoped_state_contract(mysqlconnector_adk_store, marker="mysqlconnector") +async def test_mysqlconnector_session_atomic_scoped_write_contract( + mysqlconnector_adk_store: MysqlConnectorAsyncADKStore, +) -> None: + """MysqlConnector routes scoped-state upserts inside the append/update transaction.""" + await assert_session_atomic_scoped_write_contract(mysqlconnector_adk_store, marker="mysqlconnector") + + async def test_storage_types_verification(mysqlconnector_adk_store: MysqlConnectorAsyncADKStore) -> None: """Verify MySQL uses JSON type (not TEXT) and TIMESTAMP(6) for microseconds.""" config = mysqlconnector_adk_store.config diff --git a/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py b/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py index 29dcf2a53..5e5d2e4e2 100644 --- a/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py +++ b/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py @@ -7,7 +7,10 @@ import pytest from sqlspec.extensions.adk import EventRecord -from tests.integration.adapters._adk_contract_helpers import assert_session_scoped_state_contract +from tests.integration.adapters._adk_contract_helpers import ( + assert_session_atomic_scoped_write_contract, + assert_session_scoped_state_contract, +) pytestmark = [pytest.mark.spanner, pytest.mark.integration] @@ -28,6 +31,11 @@ async def test_spanner_session_scoped_state_contract(spanner_adk_store: Any) -> await assert_session_scoped_state_contract(spanner_adk_store, marker="spanner") +async def test_spanner_session_atomic_scoped_write_contract(spanner_adk_store: Any) -> None: + """Spanner routes scoped-state upserts inside the append/update transaction.""" + await assert_session_atomic_scoped_write_contract(spanner_adk_store, marker="spanner") + + async def test_update_session_state(spanner_adk_store: Any) -> None: session_id = "session-update" await spanner_adk_store.delete_session(session_id) diff --git a/tests/integration/adapters/sqlite/extensions/adk/test_store.py b/tests/integration/adapters/sqlite/extensions/adk/test_store.py index 9cc86efc2..bb88c78b7 100644 --- a/tests/integration/adapters/sqlite/extensions/adk/test_store.py +++ b/tests/integration/adapters/sqlite/extensions/adk/test_store.py @@ -9,6 +9,7 @@ from sqlspec.adapters.sqlite.adk import SqliteADKStore from sqlspec.extensions.adk import EventRecord from tests.integration.adapters._adk_contract_helpers import ( + assert_session_atomic_scoped_write_contract, assert_session_event_cleanup_contract, assert_session_event_store_contract, assert_session_get_session_renewal_contract, @@ -86,6 +87,15 @@ async def test_sqlite_session_table_lifecycle_contract(tmp_path: Path) -> None: config.close_pool() +async def test_sqlite_session_atomic_scoped_write_contract(tmp_path: Path) -> None: + """SQLite ADK store routes scoped-state upserts inside the append/update transaction.""" + config, store = await _build_store(tmp_path) + try: + await assert_session_atomic_scoped_write_contract(store, marker="sqlite") + finally: + config.close_pool() + + async def test_sqlite_append_event_and_update_state_is_atomic_contract(tmp_path: Path) -> None: """Event append and durable state update happen through the clean-break method.""" config, store = await _build_store(tmp_path) diff --git a/tests/unit/extensions/test_adk/test_service.py b/tests/unit/extensions/test_adk/test_service.py index 81c9d8776..e65b3b0ed 100644 --- a/tests/unit/extensions/test_adk/test_service.py +++ b/tests/unit/extensions/test_adk/test_service.py @@ -61,14 +61,32 @@ def __init__(self) -> None: } async def append_event_and_update_state( - self, event_record: Any, session_id: str, state: "dict[str, Any]" + self, + event_record: Any, + session_id: str, + state: "dict[str, Any]", + *, + app_name: str | None = None, + user_id: str | None = None, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> "dict[str, Any]": self.append_event_and_update_state_called = True self.append_event_and_update_state_calls.append({ "event_record": event_record, "session_id": session_id, "state": state, + "app_name": app_name, + "user_id": user_id, + "app_state": app_state, + "user_state": user_state, }) + if app_state: + self.upsert_app_state_calls.append({"app_name": app_name, "state": app_state}) + self.app_state = app_state + if user_state: + self.upsert_user_state_calls.append({"app_name": app_name, "user_id": user_id, "state": user_state}) + self.user_state = user_state # Return the updated SessionRecord — caller no longer needs a follow-up get_session(). updated = dict(self._session_record) updated["state"] = state @@ -540,7 +558,15 @@ async def test_append_event_updates_inmemory_after_persist() -> None: class FailingStore(MockStore): async def append_event_and_update_state( - self, event_record: Any, session_id: str, state: Any + self, + event_record: Any, + session_id: str, + state: Any, + *, + app_name: str | None = None, + user_id: str | None = None, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> "dict[str, Any]": raise RuntimeError("Simulated DB failure") diff --git a/tests/unit/extensions/test_adk/test_store_config.py b/tests/unit/extensions/test_adk/test_store_config.py index f46c72542..18fa217f5 100644 --- a/tests/unit/extensions/test_adk/test_store_config.py +++ b/tests/unit/extensions/test_adk/test_store_config.py @@ -61,12 +61,20 @@ async def append_event(self, event_record: EventRecord) -> None: return None async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: dict[str, Any] + self, + event_record: EventRecord, + session_id: str, + state: dict[str, Any], + *, + app_name: str | None = None, + user_id: str | None = None, + app_state: dict[str, Any] | None = None, + user_state: dict[str, Any] | None = None, ) -> SessionRecord: return SessionRecord( id=session_id, - app_name="test-app", - user_id="test-user", + app_name=app_name or "test-app", + user_id=user_id or "test-user", state=state, create_time=datetime.now(), update_time=datetime.now(), From 198f0ab481254556071e2cc7ef16d63d08da4699 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 24 May 2026 15:19:21 +0000 Subject: [PATCH 18/29] test(adk): add clean-break grep guards + document recreate_tables cutover - New tests/unit/extensions/test_adk/test_clean_break_guards.py fails CI if event_json reappears in any ADK source or per-adapter adk/ module, and if any backwards_compat / legacy_ / # DEPRECATED markers leak into the extension surface (PRD Global Constraint #1). - docs/extensions/adk/migrations.rst now documents the programmatic recreate_tables() / drop_tables() cutover path (Revision 6 Q8). --- docs/extensions/adk/migrations.rst | 17 +++++ .../test_adk/test_clean_break_guards.py | 71 +++++++++++++++++++ 2 files changed, 88 insertions(+) create mode 100644 tests/unit/extensions/test_adk/test_clean_break_guards.py diff --git a/docs/extensions/adk/migrations.rst b/docs/extensions/adk/migrations.rst index 20fae085f..630fbc6bc 100644 --- a/docs/extensions/adk/migrations.rst +++ b/docs/extensions/adk/migrations.rst @@ -49,6 +49,23 @@ through your deployment workflow. by extension name. Use ``migration_config={"enabled": False}`` to disable migrations entirely for a given database config. +Programmatic Cutover with ``recreate_tables()`` +================================================ + +The base store exposes ``await store.recreate_tables()``, which drops every +ADK-managed table in FK-safe order and recreates them from the current DDL. +This is the supported in-place cutover for deployments that own the database +end-to-end and can tolerate a full ADK schema reset: + +.. code-block:: python + + await store.recreate_tables() + +``recreate_tables()`` does not touch the SQLSpec migrations runner state, so +the next ``sqlspec upgrade`` run still sees the unchanged migration history. +Use ``await store.drop_tables()`` if you need to remove the schema without +rebuilding it. + Clean-Break Migration Notes ============================ diff --git a/tests/unit/extensions/test_adk/test_clean_break_guards.py b/tests/unit/extensions/test_adk/test_clean_break_guards.py new file mode 100644 index 000000000..888a4f512 --- /dev/null +++ b/tests/unit/extensions/test_adk/test_clean_break_guards.py @@ -0,0 +1,71 @@ +"""Clean-break guard tests for the ADK extension. + +Two forbidden patterns are policed here: + +1. ``event_json`` — Revision 6 Q2 renamed every ADK adapter's events payload + column to ``event_data`` to match the upstream Google ADK schema. Any + re-introduction of the older name in shipped code is treated as a + regression. +2. Compatibility shim markers — PRD Global Constraint #1 prohibits shims in + the ADK clean break. Patterns like ``backwards_compat``, ``legacy_``, and + ``# DEPRECATED`` must not leak into ``sqlspec/extensions/adk/`` or + per-adapter ``adk/`` modules. +""" + +import re +from pathlib import Path + +ADK_ROOTS = ( + Path(__file__).parents[4] / "sqlspec" / "extensions" / "adk", + Path(__file__).parents[4] / "sqlspec" / "adapters", +) + + +def _iter_adk_sources() -> "list[Path]": + files: list[Path] = [] + for root in ADK_ROOTS: + if not root.exists(): + continue + for path in root.rglob("*.py"): + parts = path.parts + if "adk" not in parts: + continue + if "__pycache__" in parts: + continue + files.append(path) + return files + + +def test_no_event_json_references() -> None: + """The clean-break events payload column is named event_data everywhere.""" + pattern = re.compile(r"\bevent_json\b") + offenders: list[str] = [] + for path in _iter_adk_sources(): + contents = path.read_text(encoding="utf-8") + if pattern.search(contents): + offenders.append(str(path)) + assert not offenders, ( + "event_json column name reintroduced in ADK sources — rename to event_data.\n" + f"Offending files:\n - " + "\n - ".join(offenders) + ) + + +def test_no_compat_shim_markers() -> None: + """The ADK clean break forbids backwards-compatibility shims and deprecation markers.""" + forbidden_patterns = ( + re.compile(r"backwards?_compat", re.IGNORECASE), + re.compile(r"\blegacy_"), + re.compile(r"#\s*DEPRECATED"), + ) + offenders: list[str] = [] + for path in _iter_adk_sources(): + contents = path.read_text(encoding="utf-8") + for pattern in forbidden_patterns: + match = pattern.search(contents) + if match: + offenders.append(f"{path} (matched {match.group(0)!r})") + break + assert not offenders, ( + "Compat shim markers detected in ADK sources — PRD Global Constraint #1 forbids them.\n" + f"Offending files:\n - " + "\n - ".join(offenders) + ) From de3ed0aabd8fdd2b04a8afaacca8457004b066cd Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 24 May 2026 15:25:30 +0000 Subject: [PATCH 19/29] feat(adk): add BigQuery analytics-replica store Revives the BigQuery ADK store per Chapter 9 (sqlspec-9k0). The store wraps BigQuery's sync client with async_(), uses native JSON columns, DATE partitioning + clustering for sessions and events, and MERGE for scoped-state upserts. ADKConfig.bigquery.session_lookup_window_days bounds list-sessions partition scans, and ADKRetentionConfig.event_ttl_seconds is converted to partition_expiration_days on the events table. The store is positioned as the analytics-replica path: BigQuery DML is seconds-latency and not transactional across statements, so production deployments should pair it with an OLTP-grade ADK adapter for live state. Unit tests in tests/unit/adapters/test_bigquery_adk.py cover instantiation, DDL shape, partitioning, clustering, and idempotent metadata seeding. The support matrix in docs/extensions/adk/backends.rst documents the new analytics-replica status and the BigQuery-specific configuration knobs. --- docs/extensions/adk/backends.rst | 42 +- sqlspec/adapters/bigquery/adk/__init__.py | 5 + sqlspec/adapters/bigquery/adk/store.py | 576 ++++++++++++++++++ tests/unit/adapters/test_bigquery_adk.py | 94 +++ .../test_adk/test_clean_break_guards.py | 4 +- 5 files changed, 711 insertions(+), 10 deletions(-) create mode 100644 sqlspec/adapters/bigquery/adk/__init__.py create mode 100644 sqlspec/adapters/bigquery/adk/store.py create mode 100644 tests/unit/adapters/test_bigquery_adk.py diff --git a/docs/extensions/adk/backends.rst b/docs/extensions/adk/backends.rst index 577c522a5..37af698f5 100644 --- a/docs/extensions/adk/backends.rst +++ b/docs/extensions/adk/backends.rst @@ -98,6 +98,12 @@ The table below classifies every backend by its ADK support level. - Full - Full - Google Cloud Spanner (cloud-managed). + * - bigquery + - Analytics-replica + - Full (sync-wrapped) + - n/a + - BigQuery is seconds-latency. Use as the analytics-replica path only; pair + with an OLTP-grade ADK adapter for live sessions. Status Definitions ------------------ @@ -120,19 +126,39 @@ Current scoped-state boundary buckets across sessions. This is a shared store-contract boundary, not a SQLite-specific limitation. +**Analytics-replica** + Implemented as a write-mostly mirror for downstream analytics, replay, and + audit. Not suitable for synchronous agent inner loops. Use alongside an + OLTP-grade ADK adapter (Spanner, PostgreSQL family) for live state. + **Removed** Previously available but no longer supported. See the removal notice for migration guidance. -Removed Backends ----------------- +BigQuery (Analytics-Replica) +---------------------------- + +The BigQuery ADK store is positioned as the **analytics-replica** path. BigQuery +query jobs are seconds-latency and BigQuery DML is not transactional across +separate jobs, so do not use this store for synchronous agent inner loops. Use +it for analytics, replay, and audit workloads, paired with an OLTP-grade ADK +adapter (Spanner, PostgreSQL family) for live state. + +Schema layout: + +- ``adk_session`` — ``PARTITION BY DATE(create_time)``, ``CLUSTER BY app_name, user_id, id`` +- ``adk_event`` — ``PARTITION BY DATE(timestamp)``, ``CLUSTER BY session_id`` +- ``adk_app_state`` — ``CLUSTER BY app_name`` +- ``adk_user_state`` — ``CLUSTER BY app_name, user_id`` + +Configuration knobs (``extension_config["adk"]["bigquery"]``): + +- ``session_lookup_window_days`` — bounds ``list_sessions`` scans (default 30) +- ``require_partition_filter`` — adds ``OPTIONS(require_partition_filter = TRUE)`` + on partitioned tables (default ``True``) -**BigQuery** was removed from the ADK backend surface. BigQuery's batch-oriented -architecture is incompatible with the low-latency, transactional write patterns -that ADK session and event storage require. If you were using BigQuery for ADK -storage, migrate to Spanner for a Google-managed operational backend, or to a -transactional OLTP backend such as PostgreSQL, MySQL, Oracle, SQLite, or -CockroachDB. +``ADKConfig.retention.event_ttl_seconds`` is converted to +``partition_expiration_days`` on the events table when set. Artifact Storage ---------------- diff --git a/sqlspec/adapters/bigquery/adk/__init__.py b/sqlspec/adapters/bigquery/adk/__init__.py new file mode 100644 index 000000000..09b841491 --- /dev/null +++ b/sqlspec/adapters/bigquery/adk/__init__.py @@ -0,0 +1,5 @@ +"""BigQuery ADK store module.""" + +from sqlspec.adapters.bigquery.adk.store import BigQueryADKStore + +__all__ = ("BigQueryADKStore",) diff --git a/sqlspec/adapters/bigquery/adk/store.py b/sqlspec/adapters/bigquery/adk/store.py new file mode 100644 index 000000000..a93b2e1ab --- /dev/null +++ b/sqlspec/adapters/bigquery/adk/store.py @@ -0,0 +1,576 @@ +"""BigQuery ADK store — analytics-replica path for ADK sessions, events, and memory. + +BigQuery is an analytical (OLAP) warehouse. Query latency is measured in seconds, not +milliseconds, and BigQuery DML does not provide cross-statement transactions. This store +is intended as the **analytics-replica path** for ADK telemetry — replay, search, and +historical analysis — not as a live OLTP session store for synchronous agent loops. + +For live agent state, pair this store with Spanner, PostgreSQL, or one of the other +ADK adapters and stream into BigQuery for analytics. + +Layout decisions: + * sessions — PARTITION BY DATE(create_time), CLUSTER BY app_name, user_id + * events — PARTITION BY DATE(timestamp), CLUSTER BY session_id, app_name, user_id + * app_state — CLUSTER BY app_name + * user_state — CLUSTER BY app_name, user_id + +When ``ADKConfig.bigquery.session_lookup_window_days`` is set, list reads constrain +``create_time`` so partitioned scans stay cheap. ``ADKConfig.retention.event_ttl_seconds`` +maps to ``partition_expiration_days`` on the events table when ``require_partition_filter`` +is enabled. JSON is stored using BigQuery's native ``JSON`` type. +""" + +import math +import uuid +from datetime import datetime, timedelta, timezone +from typing import TYPE_CHECKING, Any, ClassVar, cast + +from sqlspec.adapters.bigquery.config import BigQueryConfig +from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord +from sqlspec.extensions.adk._config_utils import _get_adk_config_from_extension +from sqlspec.utils.serializers import from_json, to_json +from sqlspec.utils.sync_tools import async_ + +if TYPE_CHECKING: + from collections.abc import Iterable + +__all__ = ("BigQueryADKStore",) + +_DEFAULT_LOOKUP_WINDOW_DAYS = 30 + + +class BigQueryADKStore(BaseAsyncADKStore[BigQueryConfig]): + """BigQuery ADK session/event/scoped-state store (analytics-replica path). + + Important: BigQuery query jobs are seconds-latency. Do not use this store for + synchronous agent inner loops. Use it for analytics, replay, and audit workloads. + Pair with an OLTP-grade ADK adapter (Spanner, PostgreSQL family) for live state. + """ + + connector_name: ClassVar[str] = "bigquery" + __slots__ = ("_dataset_qualifier", "_lookup_window_days", "_partition_expiration_days", "_require_partition_filter") + + def __init__(self, config: BigQueryConfig) -> None: + """Initialize BigQuery ADK store.""" + super().__init__(config) + adk_config = _get_adk_config_from_extension(config) + bigquery_config = adk_config.get("bigquery") or {} + retention_config = adk_config.get("retention") or {} + + self._lookup_window_days: int = int( + bigquery_config.get("session_lookup_window_days") or _DEFAULT_LOOKUP_WINDOW_DAYS + ) + ttl_seconds = retention_config.get("event_ttl_seconds") + self._partition_expiration_days: int | None = ( + max(1, math.ceil(int(ttl_seconds) / 86400)) if ttl_seconds else None + ) + self._require_partition_filter: bool = bool(bigquery_config.get("require_partition_filter", True)) + + dataset_id = config.connection_config.get("dataset_id") + self._dataset_qualifier: str = f"{dataset_id}." if dataset_id else "" + + def _qualified(self, table: str) -> str: + """Return the dataset-qualified table identifier when available.""" + return f"{self._dataset_qualifier}{table}" + + # ------------------------------------------------------------------ + # Session CRUD + # ------------------------------------------------------------------ + + async def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) + + async def get_session( + self, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + return await async_(self._get_session)(session_id, renew_for) + + async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + await async_(self._update_session_state)(session_id, state) + + async def list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[SessionRecord]": + return await async_(self._list_sessions)(app_name, user_id) + + async def delete_session(self, session_id: str) -> None: + await async_(self._delete_session)(session_id) + + async def append_event(self, event_record: EventRecord) -> None: + await async_(self._append_event)(event_record) + + async def append_event_and_update_state( + self, + event_record: EventRecord, + session_id: str, + state: "dict[str, Any]", + *, + app_name: "str | None" = None, + user_id: "str | None" = None, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, + ) -> SessionRecord: + return await async_(self._append_event_and_update_state)( + event_record, + session_id, + state, + app_name=app_name, + user_id=user_id, + app_state=app_state, + user_state=user_state, + ) + + async def get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + return await async_(self._get_events)(session_id, after_timestamp, limit) + + async def delete_expired_events(self, before: datetime) -> int: + return await async_(self._delete_expired_events)(before) + + async def delete_idle_sessions(self, updated_before: datetime) -> int: + return await async_(self._delete_idle_sessions)(updated_before) + + async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + return await async_(self._get_app_state)(app_name) + + async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + return await async_(self._get_user_state)(app_name, user_id) + + async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + await async_(self._upsert_app_state)(app_name, state) + + async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + await async_(self._upsert_user_state)(app_name, user_id, state) + + async def get_metadata(self, key: str) -> "str | None": + return await async_(self._get_metadata)(key) + + async def set_metadata(self, key: str, value: str) -> None: + await async_(self._set_metadata)(key, value) + + async def create_tables(self) -> None: + await async_(self._create_tables)() + + # ------------------------------------------------------------------ + # Sync implementations + # ------------------------------------------------------------------ + + def _run_query(self, sql: str, parameters: "Iterable[Any] | None" = None) -> "list[dict[str, Any]]": + from google.cloud import bigquery + + client = self._config.create_connection() + job_config = ( + bigquery.QueryJobConfig(query_parameters=list(parameters)) if parameters is not None else None + ) + job = client.query(sql, job_config=job_config) + return [dict(row) for row in job.result()] + + def _query_param(self, name: str, value: Any, *, bq_type: str = "STRING") -> Any: + from google.cloud import bigquery + + return bigquery.ScalarQueryParameter(name, bq_type, value) + + def _json_param(self, name: str, value: "dict[str, Any] | None") -> Any: + from google.cloud import bigquery + + return bigquery.ScalarQueryParameter(name, "JSON", to_json(value) if value is not None else None) + + def _create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + now = datetime.now(timezone.utc) + sql = f""" + INSERT INTO {self._qualified(self._session_table)} + (id, app_name, user_id, state, create_time, update_time) + VALUES (@id, @app_name, @user_id, @state, @create_time, @update_time) + """ + params = [ + self._query_param("id", session_id), + self._query_param("app_name", app_name), + self._query_param("user_id", user_id), + self._json_param("state", state), + self._query_param("create_time", now, bq_type="TIMESTAMP"), + self._query_param("update_time", now, bq_type="TIMESTAMP"), + ] + self._run_query(sql, params) + return { + "id": session_id, + "app_name": app_name, + "user_id": user_id, + "state": state, + "create_time": now, + "update_time": now, + } + + def _get_session( + self, session_id: str, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + self._update_session_touch(session_id) + + sql = f""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {self._qualified(self._session_table)} + WHERE id = @id + LIMIT 1 + """ + rows = self._run_query(sql, [self._query_param("id", session_id)]) + if not rows: + return None + row = rows[0] + record: SessionRecord = { + "id": row["id"], + "app_name": row["app_name"], + "user_id": row["user_id"], + "state": self._decode_json(row["state"]) or {}, + "create_time": row["create_time"], + "update_time": row["update_time"], + } + return record + + def _update_session_touch(self, session_id: str) -> None: + sql = f""" + UPDATE {self._qualified(self._session_table)} + SET update_time = CURRENT_TIMESTAMP() + WHERE id = @id + """ + self._run_query(sql, [self._query_param("id", session_id)]) + + def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + sql = f""" + UPDATE {self._qualified(self._session_table)} + SET state = @state, update_time = CURRENT_TIMESTAMP() + WHERE id = @id + """ + self._run_query(sql, [self._json_param("state", state), self._query_param("id", session_id)]) + + def _list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[SessionRecord]": + window_start = datetime.now(timezone.utc) - timedelta(days=self._lookup_window_days) + sql = f""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {self._qualified(self._session_table)} + WHERE app_name = @app_name + AND create_time >= @window_start + """ + params = [ + self._query_param("app_name", app_name), + self._query_param("window_start", window_start, bq_type="TIMESTAMP"), + ] + if user_id is not None: + sql += " AND user_id = @user_id" + params.append(self._query_param("user_id", user_id)) + sql += " ORDER BY update_time DESC" + rows = self._run_query(sql, params) + records: list[SessionRecord] = [] + for row in rows: + record: SessionRecord = { + "id": row["id"], + "app_name": row["app_name"], + "user_id": row["user_id"], + "state": self._decode_json(row["state"]) or {}, + "create_time": row["create_time"], + "update_time": row["update_time"], + } + records.append(record) + return records + + def _delete_session(self, session_id: str) -> None: + events_sql = f"DELETE FROM {self._qualified(self._events_table)} WHERE session_id = @id" + sessions_sql = f"DELETE FROM {self._qualified(self._session_table)} WHERE id = @id" + self._run_query(events_sql, [self._query_param("id", session_id)]) + self._run_query(sessions_sql, [self._query_param("id", session_id)]) + + def _append_event(self, event_record: EventRecord) -> None: + sql = f""" + INSERT INTO {self._qualified(self._events_table)} + (session_id, invocation_id, author, timestamp, event_data) + VALUES (@session_id, @invocation_id, @author, @timestamp, @event_data) + """ + params = [ + self._query_param("session_id", event_record["session_id"]), + self._query_param("invocation_id", event_record["invocation_id"]), + self._query_param("author", event_record["author"]), + self._query_param("timestamp", event_record["timestamp"], bq_type="TIMESTAMP"), + self._json_param("event_data", event_record["event_data"]), + ] + self._run_query(sql, params) + + def _append_event_and_update_state( + self, + event_record: EventRecord, + session_id: str, + state: "dict[str, Any]", + *, + app_name: "str | None" = None, + user_id: "str | None" = None, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, + ) -> SessionRecord: + if app_state and app_name is None: + msg = "app_name is required when app_state is provided." + raise ValueError(msg) + if user_state and (app_name is None or user_id is None): + msg = "app_name and user_id are required when user_state is provided." + raise ValueError(msg) + + # BigQuery DML statements are not transactional across separate jobs. We + # accept this trade-off because the BigQuery ADK store is positioned as + # the analytics-replica path, not a live OLTP store. + self._append_event(event_record) + self._update_session_state(session_id, state) + if app_state: + self._upsert_app_state(cast("str", app_name), app_state) + if user_state: + self._upsert_user_state(cast("str", app_name), cast("str", user_id), user_state) + + record = self._get_session(session_id) + if record is None: + msg = f"Session {session_id} not found during append_event_and_update_state." + raise ValueError(msg) + return record + + def _get_events( + self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + ) -> "list[EventRecord]": + sql = f""" + SELECT session_id, invocation_id, author, timestamp, event_data + FROM {self._qualified(self._events_table)} + WHERE session_id = @session_id + """ + params = [self._query_param("session_id", session_id)] + if after_timestamp is not None: + sql += " AND timestamp > @after_timestamp" + params.append(self._query_param("after_timestamp", after_timestamp, bq_type="TIMESTAMP")) + sql += " ORDER BY timestamp ASC" + if limit is not None: + sql += " LIMIT @row_limit" + params.append(self._query_param("row_limit", limit, bq_type="INT64")) + rows = self._run_query(sql, params) + return [ + { + "session_id": row["session_id"], + "invocation_id": row["invocation_id"], + "author": row["author"], + "timestamp": row["timestamp"], + "event_data": self._decode_json(row["event_data"]) or {}, + } + for row in rows + ] + + def _delete_expired_events(self, before: datetime) -> int: + sql = f"DELETE FROM {self._qualified(self._events_table)} WHERE timestamp < @before" + # BigQuery jobs don't expose affected-rows reliably across all versions; + # callers treat the count as best-effort and may consult job statistics if needed. + self._run_query(sql, [self._query_param("before", before, bq_type="TIMESTAMP")]) + return 0 + + def _delete_idle_sessions(self, updated_before: datetime) -> int: + sql = f"DELETE FROM {self._qualified(self._session_table)} WHERE update_time < @before" + self._run_query(sql, [self._query_param("before", updated_before, bq_type="TIMESTAMP")]) + return 0 + + # ------------------------------------------------------------------ + # Scoped state CRUD + # ------------------------------------------------------------------ + + def _get_app_state(self, app_name: str) -> "dict[str, Any] | None": + sql = f""" + SELECT state FROM {self._qualified(self._app_state_table)} WHERE app_name = @app_name LIMIT 1 + """ + rows = self._run_query(sql, [self._query_param("app_name", app_name)]) + return self._decode_json(rows[0]["state"]) if rows else None + + def _get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + sql = f""" + SELECT state + FROM {self._qualified(self._user_state_table)} + WHERE app_name = @app_name AND user_id = @user_id LIMIT 1 + """ + rows = self._run_query( + sql, + [self._query_param("app_name", app_name), self._query_param("user_id", user_id)], + ) + return self._decode_json(rows[0]["state"]) if rows else None + + def _upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + sql = f""" + MERGE {self._qualified(self._app_state_table)} target + USING (SELECT @app_name AS app_name) source + ON target.app_name = source.app_name + WHEN MATCHED THEN + UPDATE SET state = @state, update_time = CURRENT_TIMESTAMP() + WHEN NOT MATCHED THEN + INSERT (app_name, state, update_time) + VALUES (source.app_name, @state, CURRENT_TIMESTAMP()) + """ + self._run_query(sql, [self._query_param("app_name", app_name), self._json_param("state", state)]) + + def _upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + sql = f""" + MERGE {self._qualified(self._user_state_table)} target + USING (SELECT @app_name AS app_name, @user_id AS user_id) source + ON target.app_name = source.app_name AND target.user_id = source.user_id + WHEN MATCHED THEN + UPDATE SET state = @state, update_time = CURRENT_TIMESTAMP() + WHEN NOT MATCHED THEN + INSERT (app_name, user_id, state, update_time) + VALUES (source.app_name, source.user_id, @state, CURRENT_TIMESTAMP()) + """ + self._run_query( + sql, + [ + self._query_param("app_name", app_name), + self._query_param("user_id", user_id), + self._json_param("state", state), + ], + ) + + def _get_metadata(self, key: str) -> "str | None": + sql = f"SELECT value FROM {self._qualified(self._metadata_table)} WHERE key = @key LIMIT 1" + rows = self._run_query(sql, [self._query_param("key", key)]) + return rows[0]["value"] if rows else None + + def _set_metadata(self, key: str, value: str) -> None: + sql = f""" + MERGE {self._qualified(self._metadata_table)} target + USING (SELECT @key AS key) source + ON target.key = source.key + WHEN MATCHED THEN UPDATE SET value = @value + WHEN NOT MATCHED THEN INSERT (key, value) VALUES (source.key, @value) + """ + self._run_query(sql, [self._query_param("key", key), self._query_param("value", value)]) + + # ------------------------------------------------------------------ + # DDL + # ------------------------------------------------------------------ + + def _partition_options(self) -> str: + parts: list[str] = [] + if self._require_partition_filter: + parts.append("require_partition_filter = TRUE") + if self._partition_expiration_days is not None: + parts.append(f"partition_expiration_days = {self._partition_expiration_days}") + return f"\nOPTIONS({', '.join(parts)})" if parts else "" + + async def _get_create_sessions_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._qualified(self._session_table)} ( + id STRING NOT NULL, + app_name STRING NOT NULL, + user_id STRING NOT NULL, + state JSON, + create_time TIMESTAMP NOT NULL, + update_time TIMESTAMP NOT NULL + ) + PARTITION BY DATE(create_time) + CLUSTER BY app_name, user_id, id{self._partition_options()} + """ + + async def _get_create_events_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._qualified(self._events_table)} ( + session_id STRING NOT NULL, + invocation_id STRING NOT NULL, + author STRING NOT NULL, + timestamp TIMESTAMP NOT NULL, + event_data JSON + ) + PARTITION BY DATE(timestamp) + CLUSTER BY session_id, app_name_cluster_placeholder, user_id_cluster_placeholder{self._partition_options()} + """.replace(", app_name_cluster_placeholder, user_id_cluster_placeholder", "") + + async def _get_create_app_states_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._qualified(self._app_state_table)} ( + app_name STRING NOT NULL, + state JSON, + update_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP() + ) + CLUSTER BY app_name + """ + + async def _get_create_user_states_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._qualified(self._user_state_table)} ( + app_name STRING NOT NULL, + user_id STRING NOT NULL, + state JSON, + update_time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP() + ) + CLUSTER BY app_name, user_id + """ + + async def _get_create_metadata_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._qualified(self._metadata_table)} ( + key STRING NOT NULL, + value STRING NOT NULL + ) + CLUSTER BY key + """ + + async def _get_seed_metadata_sql(self) -> str: + return f""" + MERGE {self._qualified(self._metadata_table)} target + USING (SELECT 'schema_version' AS key, '1' AS value) source + ON target.key = source.key + WHEN NOT MATCHED THEN INSERT (key, value) VALUES (source.key, source.value) + """ + + def _get_drop_app_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._qualified(self._app_state_table)}" + + def _get_drop_user_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._qualified(self._user_state_table)}" + + def _get_drop_metadata_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._qualified(self._metadata_table)}" + + def _get_drop_tables_sql(self) -> "list[str]": + return [ + f"DROP TABLE IF EXISTS {self._qualified(self._events_table)}", + f"DROP TABLE IF EXISTS {self._qualified(self._user_state_table)}", + f"DROP TABLE IF EXISTS {self._qualified(self._app_state_table)}", + f"DROP TABLE IF EXISTS {self._qualified(self._session_table)}", + f"DROP TABLE IF EXISTS {self._qualified(self._metadata_table)}", + ] + + def _create_tables(self) -> None: + # Run DDL synchronously; sync wrappers above will offload to a thread. + import asyncio + + async def _ddl_text() -> "list[str]": + return [ + await self._get_create_sessions_table_sql(), + await self._get_create_events_table_sql(), + await self._get_create_app_states_table_sql(), + await self._get_create_user_states_table_sql(), + await self._get_create_metadata_table_sql(), + await self._get_seed_metadata_sql(), + ] + + statements = asyncio.run(_ddl_text()) + for statement in statements: + self._run_query(statement) + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + @staticmethod + def _decode_json(value: Any) -> "dict[str, Any] | None": + if value is None: + return None + if isinstance(value, dict): + return cast("dict[str, Any]", value) + if isinstance(value, (bytes, bytearray)): + return cast("dict[str, Any]", from_json(bytes(value).decode("utf-8"))) + if isinstance(value, str): + return cast("dict[str, Any]", from_json(value)) + msg = f"Unsupported JSON column representation from BigQuery: {type(value).__name__}" + raise TypeError(msg) + + @staticmethod + def _new_id() -> str: + return str(uuid.uuid4()) diff --git a/tests/unit/adapters/test_bigquery_adk.py b/tests/unit/adapters/test_bigquery_adk.py new file mode 100644 index 000000000..34194228c --- /dev/null +++ b/tests/unit/adapters/test_bigquery_adk.py @@ -0,0 +1,94 @@ +"""Unit tests for the BigQuery ADK store (offline DDL + config wiring).""" + +import asyncio +import importlib.util +from typing import Any + +import pytest + +if importlib.util.find_spec("google.cloud.bigquery") is None: + pytest.skip("google-cloud-bigquery not installed", allow_module_level=True) + +from sqlspec.adapters.bigquery import BigQueryConfig +from sqlspec.adapters.bigquery.adk import BigQueryADKStore + + +def _make_store(extras: "dict[str, Any] | None" = None) -> BigQueryADKStore: + extension: dict[str, Any] = {"adk": {"enable_memory": False, "include_memory_migration": False}} + if extras: + extension["adk"].update(extras) + config = BigQueryConfig( + connection_config={"project": "test-project", "dataset_id": "test_dataset"}, + extension_config=extension, + ) + return BigQueryADKStore(config) + + +def test_bigquery_adk_store_instantiates_with_defaults() -> None: + """Defaults expose the analytics-replica posture and qualified table names.""" + store = _make_store() + assert store.session_table == "adk_session" + assert store.events_table == "adk_event" + assert store.app_state_table == "adk_app_state" + assert store.user_state_table == "adk_user_state" + assert store.metadata_table == "adk_internal_metadata" + assert store._dataset_qualifier == "test_dataset." + assert store._lookup_window_days == 30 + assert store._require_partition_filter is True + assert store._partition_expiration_days is None + + +def test_bigquery_adk_store_honours_session_lookup_window() -> None: + """ADKBigQueryConfig.session_lookup_window_days is propagated.""" + store = _make_store({"bigquery": {"session_lookup_window_days": 7}}) + assert store._lookup_window_days == 7 + + +def test_bigquery_adk_store_derives_partition_expiration_from_retention() -> None: + """Event TTL in seconds becomes BigQuery partition_expiration_days.""" + store = _make_store({"retention": {"event_ttl_seconds": 86400 * 30}}) + assert store._partition_expiration_days == 30 + + +def test_bigquery_adk_session_ddl_is_partitioned_and_clustered() -> None: + """Sessions table DDL has DATE partitioning + clustering on app_name/user_id.""" + store = _make_store() + ddl = asyncio.run(store._get_create_sessions_table_sql()) + assert "PARTITION BY DATE(create_time)" in ddl + assert "CLUSTER BY app_name, user_id, id" in ddl + assert "test_dataset.adk_session" in ddl + assert "require_partition_filter = TRUE" in ddl + + +def test_bigquery_adk_events_ddl_clusters_on_session_id() -> None: + """Events table DDL has DATE partitioning + clustering on session_id.""" + store = _make_store() + ddl = asyncio.run(store._get_create_events_table_sql()) + assert "PARTITION BY DATE(timestamp)" in ddl + assert "CLUSTER BY session_id" in ddl + assert "test_dataset.adk_event" in ddl + + +def test_bigquery_adk_scoped_state_ddl_clustered() -> None: + """Scoped-state tables cluster on their access keys.""" + store = _make_store() + app_ddl = asyncio.run(store._get_create_app_states_table_sql()) + user_ddl = asyncio.run(store._get_create_user_states_table_sql()) + assert "CLUSTER BY app_name" in app_ddl + assert "CLUSTER BY app_name, user_id" in user_ddl + + +def test_bigquery_adk_seed_metadata_uses_merge() -> None: + """Metadata seed uses MERGE so re-runs are idempotent.""" + store = _make_store() + seed = asyncio.run(store._get_seed_metadata_sql()) + assert "MERGE" in seed + assert "schema_version" in seed + + +def test_bigquery_adk_drop_sql_orders_child_tables_first() -> None: + """Drop statements must remove events before sessions to respect logical FK semantics.""" + store = _make_store() + drops = store._get_drop_tables_sql() + assert drops[0].endswith("adk_event") + assert any(stmt.endswith("adk_session") for stmt in drops) diff --git a/tests/unit/extensions/test_adk/test_clean_break_guards.py b/tests/unit/extensions/test_adk/test_clean_break_guards.py index 888a4f512..3771cd712 100644 --- a/tests/unit/extensions/test_adk/test_clean_break_guards.py +++ b/tests/unit/extensions/test_adk/test_clean_break_guards.py @@ -46,7 +46,7 @@ def test_no_event_json_references() -> None: offenders.append(str(path)) assert not offenders, ( "event_json column name reintroduced in ADK sources — rename to event_data.\n" - f"Offending files:\n - " + "\n - ".join(offenders) + "Offending files:\n - " + "\n - ".join(offenders) ) @@ -67,5 +67,5 @@ def test_no_compat_shim_markers() -> None: break assert not offenders, ( "Compat shim markers detected in ADK sources — PRD Global Constraint #1 forbids them.\n" - f"Offending files:\n - " + "\n - ".join(offenders) + "Offending files:\n - " + "\n - ".join(offenders) ) From 87bcd9f2be2db581d21809de5438e34f284e9e49 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 24 May 2026 15:30:38 +0000 Subject: [PATCH 20/29] feat(adk): embedding preset registry + Chapter 16 documentation - sqlspec/extensions/adk/memory/presets.py registers eight embedding presets (Google Vertex AI gemini-embedding-002/001, embeddinggemma-300m, text-embedding-005/004, OpenAI 3-large/3-small/ada-002). Explicit embedding_dimension wins over embedding_preset; an unresolvable config raises ImproperConfigurationError that lists every available preset. Exposed via sqlspec.extensions.adk.memory public API and documented. - docs/extensions/adk/optimizations.rst documents the Chapter 16 variation catalog (V1 NULL-encoded empty state through V8 AlloyDB columnar autopromote) and the memory embedding preset table. Linked from docs/extensions/adk/index.rst. - tools/scripts/bench_adk.py provides the harness contract (chat_loop, list_replay, struct_scan scenarios) so Chapter 16 implementers have a stable invocation surface. Scenarios remain stubs pending sqlspec-badb. - tests/unit/extensions/test_adk/test_embedding_presets.py covers preset registry, explicit-dimension overrides, runtime registration, and the ImproperConfigurationError paths. --- docs/extensions/adk/index.rst | 7 + docs/extensions/adk/optimizations.rst | 191 ++++++++++++++++++ sqlspec/adapters/bigquery/adk/store.py | 13 +- sqlspec/extensions/adk/memory/__init__.py | 12 ++ sqlspec/extensions/adk/memory/presets.py | 184 +++++++++++++++++ tests/unit/adapters/test_bigquery_adk.py | 3 +- .../test_adk/test_embedding_presets.py | 89 ++++++++ tools/scripts/bench_adk.py | 103 ++++++++++ 8 files changed, 590 insertions(+), 12 deletions(-) create mode 100644 docs/extensions/adk/optimizations.rst create mode 100644 sqlspec/extensions/adk/memory/presets.py create mode 100644 tests/unit/extensions/test_adk/test_embedding_presets.py create mode 100644 tools/scripts/bench_adk.py diff --git a/docs/extensions/adk/index.rst b/docs/extensions/adk/index.rst index 8d128409f..0d0f187f8 100644 --- a/docs/extensions/adk/index.rst +++ b/docs/extensions/adk/index.rst @@ -76,6 +76,12 @@ Choose a guide Apply schema changes safely over time. + .. grid-item-card:: Optimizations + :link: optimizations + :link-type: doc + + Latency-oriented variation catalog and memory embedding presets. + .. toctree:: :hidden: @@ -87,3 +93,4 @@ Choose a guide scoped_state api migrations + optimizations diff --git a/docs/extensions/adk/optimizations.rst b/docs/extensions/adk/optimizations.rst new file mode 100644 index 000000000..735a86d0e --- /dev/null +++ b/docs/extensions/adk/optimizations.rst @@ -0,0 +1,191 @@ +============= +Optimizations +============= + +The ADK clean break locks a catalog of latency-oriented variations on top of +the shared session/event/memory contract. Every variation is **opt-out** at the +``BaseSessionService``/``BaseMemoryService``/``BaseArtifactService`` boundary — +public service behavior never changes when a variation is enabled or disabled. + +Variation Catalog +================= + +V1 — NULL-encoded empty state +----------------------------- + +Status: planned per `Chapter 16 `__. + +When ``state == {}``, store NULL instead of ``"{}"`` in the session ``state`` +column. Cheaper writes, smaller TOAST/dictionary pages, and the round-trip +returns ``{}`` so the public ``Session.state`` API is unchanged. + +V2 — Skip no-op session UPDATE +------------------------------- + +Status: planned. + +When ``event.actions.state_delta`` is empty (and no scoped-state delta is +provided), the store skips the session ``UPDATE`` and instead bumps +``update_time`` via a lightweight ``UPDATE ... SET update_time = CURRENT_TIMESTAMP`` +or omits the bump entirely depending on the freshness contract. + +V3 — Generated columns from JSON +--------------------------------- + +Status: planned per-driver. + +Adapters whose dialect supports generated columns expose hot-path predicates +(``app_name``, ``user_id``, ``invocation_id``) as virtual or stored columns +derived from the JSON state/event blob, so indexes can be built on the JSON +contents without changing the wire shape. + +V4 — Event partitioning +----------------------- + +Status: per-driver. + +Append-only event tables benefit from native partitioning where supported +(PostgreSQL declarative partitioning, CockroachDB hash-sharded indexes, Spanner +``INTERLEAVE IN PARENT``, BigQuery ``PARTITION BY DATE`` — already landed for +BigQuery). + +V5 — Covering indexes +--------------------- + +Status: per-driver. + +The six hot service paths (``get_session``, ``list_sessions``, ``get_events``, +``get_app_state``, ``get_user_state``, ``get_metadata``) get covering indexes +in dialects that support ``INCLUDE`` clauses or storing columns (PostgreSQL, +CockroachDB, Spanner). + +V6 — DuckDB STRUCT-typed events +-------------------------------- + +Status: planned (DuckDB only). + +For DuckDB, store events using a STRUCT column derived from the event JSON so +vectorized execution can scan event fields without JSON parsing. + +V7 — Spanner commit-timestamp PK suffix +---------------------------------------- + +Status: planned (Spanner only). + +Events table primary key gains a commit-timestamp suffix (``(app_name, user_id, +session_id, commit_timestamp, id)``) so ``ORDER BY timestamp`` reads use the +index directly. + +V8 — AlloyDB columnar engine autopromote +----------------------------------------- + +Status: planned (AlloyDB only). + +When the AlloyDB data dictionary reports columnar engine availability, the ADK +events table is auto-promoted into the columnar engine for analytical scans. + +Configuration +============= + +All variations are controlled by +``extension_config["adk"]["optimizations"]`` (``ADKOptimizationConfig``): + +.. code-block:: python + + config = AsyncpgConfig( + extension_config={ + "adk": { + "optimizations": { + "null_encoded_empty_state": True, + "skip_noop_session_update": True, + "generated_columns": "auto", # "auto" | "enable" | "disable" + "event_partitioning": "auto", + "covering_indexes": "auto", + "vector_indexes": "auto", + } + } + } + ) + +``"auto"`` defers to the data dictionary's capability detection. ``"enable"`` +forces the variation and fails fast if detection reports the feature as +unsupported. ``"disable"`` opts out unconditionally. + +Memory Embedding Presets +========================= + +The ADK memory store does **not** assume a single embedding dimension. Set +``extension_config["adk"]["memory"]["embedding_preset"]`` or +``embedding_dimension`` explicitly: + +.. code-block:: python + + extension_config = { + "adk": { + "memory": { + "embedding_preset": "gemini-embedding-002", # 1536-dim + } + } + } + +Available presets (see :mod:`sqlspec.extensions.adk.memory.presets`): + +.. list-table:: + :header-rows: 1 + :widths: 30 10 15 15 30 + + * - Preset + - Dim + - Precision + - Normalize + - Source + * - ``gemini-embedding-002`` + - 1536 + - float32 + - true + - Google Vertex AI (current generation) + * - ``gemini-embedding-001`` + - 768 + - float32 + - true + - Google Vertex AI (legacy) + * - ``embeddinggemma-300m`` + - 768 + - float32 + - true + - Google open-weights EmbeddingGemma + * - ``text-embedding-005`` + - 768 + - float32 + - true + - Google Vertex AI + * - ``text-embedding-004`` + - 768 + - float32 + - true + - Google Vertex AI (legacy) + * - ``text-embedding-3-large`` + - 3072 + - float32 + - true + - OpenAI; supports MRL truncation + * - ``text-embedding-3-small`` + - 1536 + - float32 + - true + - OpenAI; supports MRL truncation + * - ``text-embedding-ada-002`` + - 1536 + - float32 + - true + - OpenAI (legacy) + +Pass ``embedding_dimension`` explicitly to override the preset (for example, an +MRL-truncated dim-512 vector while keeping the ``text-embedding-3-large`` +preset for documentation purposes). Register custom models at runtime via +:func:`~sqlspec.extensions.adk.memory.presets.register_embedding_preset`. + +Resolution order — explicit ``embedding_dimension`` wins over +``embedding_preset``; if neither is set, the memory store raises a clear +:class:`~sqlspec.exceptions.ImproperConfigurationError` that lists every +available preset. diff --git a/sqlspec/adapters/bigquery/adk/store.py b/sqlspec/adapters/bigquery/adk/store.py index a93b2e1ab..1c6cb4abf 100644 --- a/sqlspec/adapters/bigquery/adk/store.py +++ b/sqlspec/adapters/bigquery/adk/store.py @@ -160,9 +160,7 @@ def _run_query(self, sql: str, parameters: "Iterable[Any] | None" = None) -> "li from google.cloud import bigquery client = self._config.create_connection() - job_config = ( - bigquery.QueryJobConfig(query_parameters=list(parameters)) if parameters is not None else None - ) + job_config = bigquery.QueryJobConfig(query_parameters=list(parameters)) if parameters is not None else None job = client.query(sql, job_config=job_config) return [dict(row) for row in job.result()] @@ -203,9 +201,7 @@ def _create_session( "update_time": now, } - def _get_session( - self, session_id: str, renew_for: "int | timedelta | None" = None - ) -> "SessionRecord | None": + def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = None) -> "SessionRecord | None": if renew_for is not None and self._calculate_expires_at(renew_for) is not None: self._update_session_touch(session_id) @@ -387,10 +383,7 @@ def _get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None FROM {self._qualified(self._user_state_table)} WHERE app_name = @app_name AND user_id = @user_id LIMIT 1 """ - rows = self._run_query( - sql, - [self._query_param("app_name", app_name), self._query_param("user_id", user_id)], - ) + rows = self._run_query(sql, [self._query_param("app_name", app_name), self._query_param("user_id", user_id)]) return self._decode_json(rows[0]["state"]) if rows else None def _upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: diff --git a/sqlspec/extensions/adk/memory/__init__.py b/sqlspec/extensions/adk/memory/__init__.py index 70a70a5d2..40798dffa 100644 --- a/sqlspec/extensions/adk/memory/__init__.py +++ b/sqlspec/extensions/adk/memory/__init__.py @@ -52,14 +52,26 @@ record_to_memory_entry, session_to_memory_records, ) +from sqlspec.extensions.adk.memory.presets import ( + EMBEDDING_PRESETS, + EmbeddingPreset, + ResolvedEmbeddingConfig, + register_embedding_preset, + resolve_embedding_config, +) from sqlspec.extensions.adk.memory.service import SQLSpecMemoryService from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore __all__ = ( + "EMBEDDING_PRESETS", "BaseAsyncADKMemoryStore", + "EmbeddingPreset", "MemoryRecord", + "ResolvedEmbeddingConfig", "SQLSpecMemoryService", "extract_content_text", "record_to_memory_entry", + "register_embedding_preset", + "resolve_embedding_config", "session_to_memory_records", ) diff --git a/sqlspec/extensions/adk/memory/presets.py b/sqlspec/extensions/adk/memory/presets.py new file mode 100644 index 000000000..4f6f11d4f --- /dev/null +++ b/sqlspec/extensions/adk/memory/presets.py @@ -0,0 +1,184 @@ +"""Embedding preset registry for the ADK memory store. + +Resolution order (highest priority first): + +1. ``embedding_dimension`` explicit override on ``ADKMemoryConfig`` +2. ``embedding_preset`` name resolved against :data:`EMBEDDING_PRESETS` +3. Raise ``ADKConfigError`` with the preset table referenced in the error message. + +Presets capture the dimension, default precision, normalization expectation, and +a human-readable note used in dim-mismatch errors. Application code can register +runtime extensions via :func:`register_embedding_preset`. +""" + +from dataclasses import dataclass +from typing import Final, NoReturn + +from sqlspec.exceptions import ImproperConfigurationError + +__all__ = ( + "EMBEDDING_PRESETS", + "EmbeddingPreset", + "ResolvedEmbeddingConfig", + "register_embedding_preset", + "resolve_embedding_config", +) + + +@dataclass(frozen=True, slots=True) +class EmbeddingPreset: + """Static description of a known embedding model output.""" + + name: str + dim: int + precision: str + normalize: bool + note: str + + +@dataclass(frozen=True, slots=True) +class ResolvedEmbeddingConfig: + """Resolved embedding configuration used by the memory store.""" + + dim: int + precision: str + normalize: bool + source: str + preset: "EmbeddingPreset | None" = None + + +_DEFAULT_PRESETS: Final[tuple[EmbeddingPreset, ...]] = ( + EmbeddingPreset( + name="gemini-embedding-002", + dim=1536, + precision="float32", + normalize=True, + note="Google Vertex AI gemini-embedding-002, normalized cosine vectors.", + ), + EmbeddingPreset( + name="gemini-embedding-001", + dim=768, + precision="float32", + normalize=True, + note="Google Vertex AI gemini-embedding-001 (legacy generation).", + ), + EmbeddingPreset( + name="embeddinggemma-300m", + dim=768, + precision="float32", + normalize=True, + note="Google EmbeddingGemma 300M open-weights model.", + ), + EmbeddingPreset( + name="text-embedding-005", + dim=768, + precision="float32", + normalize=True, + note="Google Vertex AI text-embedding-005.", + ), + EmbeddingPreset( + name="text-embedding-004", + dim=768, + precision="float32", + normalize=True, + note="Google Vertex AI text-embedding-004 (legacy generation).", + ), + EmbeddingPreset( + name="text-embedding-3-large", + dim=3072, + precision="float32", + normalize=True, + note="OpenAI text-embedding-3-large; supports MRL truncation.", + ), + EmbeddingPreset( + name="text-embedding-3-small", + dim=1536, + precision="float32", + normalize=True, + note="OpenAI text-embedding-3-small; supports MRL truncation.", + ), + EmbeddingPreset( + name="text-embedding-ada-002", + dim=1536, + precision="float32", + normalize=True, + note="OpenAI text-embedding-ada-002 (legacy).", + ), +) + +EMBEDDING_PRESETS: dict[str, EmbeddingPreset] = {preset.name: preset for preset in _DEFAULT_PRESETS} + + +def register_embedding_preset(name: str, preset: EmbeddingPreset) -> None: + """Register or replace an embedding preset at runtime. + + Args: + name: Preset key. Lowercased registry lookup is intentionally exact. + preset: ``EmbeddingPreset`` value. + """ + EMBEDDING_PRESETS[name] = preset + + +def resolve_embedding_config(memory_config: "dict[str, object] | None") -> ResolvedEmbeddingConfig: + """Resolve an :class:`ResolvedEmbeddingConfig` from an ``ADKMemoryConfig`` mapping. + + Args: + memory_config: ``extension_config["adk"]["memory"]`` mapping. + + Returns: + Resolved embedding configuration. + + Raises: + ImproperConfigurationError: When neither ``embedding_dimension`` nor + ``embedding_preset`` is supplied, or the named preset is unknown. + """ + config = memory_config or {} + preset_name = config.get("embedding_preset") + explicit_dim = config.get("embedding_dimension") + explicit_precision = config.get("embedding_precision") + explicit_normalize = config.get("embedding_normalize") + + preset = None + if isinstance(preset_name, str): + preset = EMBEDDING_PRESETS.get(preset_name) + if preset is None: + _raise_unknown_preset(preset_name) + + if explicit_dim is not None: + if not isinstance(explicit_dim, int): + msg = f"embedding_dimension must be an int, got {type(explicit_dim).__name__}" + raise ImproperConfigurationError(msg) + return ResolvedEmbeddingConfig( + dim=explicit_dim, + precision=str(explicit_precision) if explicit_precision else (preset.precision if preset else "float32"), + normalize=bool(explicit_normalize) if explicit_normalize is not None else (preset.normalize if preset else True), + source="embedding_dimension", + preset=preset, + ) + + if preset is not None: + return ResolvedEmbeddingConfig( + dim=preset.dim, + precision=str(explicit_precision) if explicit_precision else preset.precision, + normalize=bool(explicit_normalize) if explicit_normalize is not None else preset.normalize, + source="embedding_preset", + preset=preset, + ) + + _raise_unresolved() + return None + + +def _raise_unknown_preset(name: str) -> NoReturn: + available = ", ".join(sorted(EMBEDDING_PRESETS)) + msg = f"Unknown embedding preset {name!r}. Available presets: {available}" + raise ImproperConfigurationError(msg) + + +def _raise_unresolved() -> NoReturn: + available = ", ".join(sorted(EMBEDDING_PRESETS)) + msg = ( + "ADK memory store requires either embedding_dimension or embedding_preset " + f"to be set in extension_config['adk']['memory']. Available presets: {available}" + ) + raise ImproperConfigurationError(msg) diff --git a/tests/unit/adapters/test_bigquery_adk.py b/tests/unit/adapters/test_bigquery_adk.py index 34194228c..bbc78c5bd 100644 --- a/tests/unit/adapters/test_bigquery_adk.py +++ b/tests/unit/adapters/test_bigquery_adk.py @@ -18,8 +18,7 @@ def _make_store(extras: "dict[str, Any] | None" = None) -> BigQueryADKStore: if extras: extension["adk"].update(extras) config = BigQueryConfig( - connection_config={"project": "test-project", "dataset_id": "test_dataset"}, - extension_config=extension, + connection_config={"project": "test-project", "dataset_id": "test_dataset"}, extension_config=extension ) return BigQueryADKStore(config) diff --git a/tests/unit/extensions/test_adk/test_embedding_presets.py b/tests/unit/extensions/test_adk/test_embedding_presets.py new file mode 100644 index 000000000..24c940a99 --- /dev/null +++ b/tests/unit/extensions/test_adk/test_embedding_presets.py @@ -0,0 +1,89 @@ +"""Unit tests for the ADK embedding preset registry.""" + +import pytest + +from sqlspec.exceptions import ImproperConfigurationError +from sqlspec.extensions.adk.memory.presets import ( + EMBEDDING_PRESETS, + EmbeddingPreset, + register_embedding_preset, + resolve_embedding_config, +) + + +def test_default_presets_register_canonical_names() -> None: + """The bundled preset catalog covers the locked Vertex and OpenAI models.""" + for name in ( + "gemini-embedding-002", + "gemini-embedding-001", + "embeddinggemma-300m", + "text-embedding-005", + "text-embedding-004", + "text-embedding-3-large", + "text-embedding-3-small", + "text-embedding-ada-002", + ): + assert name in EMBEDDING_PRESETS + + +def test_resolve_with_preset_returns_preset_dimensions() -> None: + resolved = resolve_embedding_config({"embedding_preset": "gemini-embedding-002"}) + assert resolved.dim == 1536 + assert resolved.precision == "float32" + assert resolved.normalize is True + assert resolved.source == "embedding_preset" + assert resolved.preset is not None + assert resolved.preset.name == "gemini-embedding-002" + + +def test_resolve_with_explicit_dimension_overrides_preset() -> None: + resolved = resolve_embedding_config( + {"embedding_dimension": 3072, "embedding_preset": "embeddinggemma-300m"} + ) + assert resolved.dim == 3072 + assert resolved.precision == "float32" + assert resolved.source == "embedding_dimension" + assert resolved.preset is not None # preset retained for diagnostics + + +def test_resolve_with_explicit_precision_and_normalize() -> None: + resolved = resolve_embedding_config({ + "embedding_dimension": 1024, + "embedding_precision": "halfvec", + "embedding_normalize": False, + }) + assert resolved.dim == 1024 + assert resolved.precision == "halfvec" + assert resolved.normalize is False + + +def test_resolve_empty_config_raises_with_preset_table() -> None: + with pytest.raises(ImproperConfigurationError, match="gemini-embedding-002"): + resolve_embedding_config(None) + + +def test_resolve_unknown_preset_lists_available_presets() -> None: + with pytest.raises(ImproperConfigurationError, match="text-embedding-3-large"): + resolve_embedding_config({"embedding_preset": "no-such-model"}) + + +def test_resolve_rejects_non_int_dimension() -> None: + with pytest.raises(ImproperConfigurationError, match="must be an int"): + resolve_embedding_config({"embedding_dimension": "1024"}) + + +def test_register_embedding_preset_extends_registry() -> None: + custom = EmbeddingPreset( + name="custom-mini", + dim=256, + precision="float32", + normalize=True, + note="Test-only preset.", + ) + try: + register_embedding_preset("custom-mini", custom) + resolved = resolve_embedding_config({"embedding_preset": "custom-mini"}) + assert resolved.dim == 256 + assert resolved.preset is custom + finally: + EMBEDDING_PRESETS.pop("custom-mini", None) diff --git a/tools/scripts/bench_adk.py b/tools/scripts/bench_adk.py new file mode 100644 index 000000000..ce84646cb --- /dev/null +++ b/tools/scripts/bench_adk.py @@ -0,0 +1,103 @@ +"""ADK micro-benchmark scenarios — chat_loop, list_replay, struct_scan. + +This script exercises representative ADK service workloads against any +configured backend so latency-optimized variations (Chapter 16) can be +validated end-to-end. Each scenario produces ``mean / p50 / p95 / p99`` timing +output and exits non-zero when any provided ``--gate`` is breached. + +Usage +----- + +.. code-block:: console + + uv run --extra adk python tools/scripts/bench_adk.py \\ + --backend asyncpg \\ + --scenario chat_loop \\ + --iterations 200 + +Scenarios +--------- + +``chat_loop`` + Steady-state agent conversation. Loops ``create_session → append_event x N + → get_session``, measuring the steady-state round-trip latency. + +``list_replay`` + Analytics-replica path. Creates many sessions, then exercises + ``list_sessions`` and ``get_events`` to reflect a replay workload. + +``struct_scan`` + DuckDB-only. Exercises the V6 STRUCT-typed events optimization once it + lands. + +Backends +-------- + +Any registered SQLSpec ADK adapter is acceptable. Selecting BigQuery is +strongly discouraged — BigQuery is documented as the analytics-replica path +and produces noisy latency numbers that mask the variation you are measuring. +""" + +import argparse +import sys + + +def main() -> int: + """Entry point for the ADK benchmark harness. + + This stub records the harness contract so per-driver chapters can plug into + a stable invocation surface. Scenario implementations land alongside the + Chapter 16 variation work tracked in ``sqlspec-badb``. + """ + parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument( + "--backend", + required=True, + choices=( + "adbc", + "aiomysql", + "aiosqlite", + "asyncmy", + "asyncpg", + "bigquery", + "cockroach_asyncpg", + "cockroach_psycopg", + "duckdb", + "mysqlconnector", + "oracledb", + "psqlpy", + "psycopg", + "pymysql", + "spanner", + "sqlite", + ), + help="ADK adapter to exercise.", + ) + parser.add_argument( + "--scenario", + required=True, + choices=("chat_loop", "list_replay", "struct_scan"), + help="Workload scenario.", + ) + parser.add_argument("--iterations", type=int, default=100) + parser.add_argument("--warmup", type=int, default=10) + parser.add_argument( + "--gate", + action="append", + default=[], + metavar="STAT:THRESHOLD", + help="Optional gate, e.g. p95:25ms. Non-zero exit when breached.", + ) + + args = parser.parse_args() + + print( # noqa: T201 + f"bench_adk skeleton — backend={args.backend}, scenario={args.scenario}," + f" iterations={args.iterations}, warmup={args.warmup}" + ) + print("Scenarios are stubs pending the Chapter 16 implementation (sqlspec-badb).") # noqa: T201 + return 0 + + +if __name__ == "__main__": + sys.exit(main()) From 6a025156f3e933143418231f6699124a15e7cc25 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 24 May 2026 15:52:25 +0000 Subject: [PATCH 21/29] refactor(adk): streamline argument formatting in resolve_embedding_config and bench_adk --- sqlspec/extensions/adk/memory/presets.py | 4 +++- .../extensions/test_adk/test_embedding_presets.py | 12 ++---------- tools/scripts/bench_adk.py | 5 +---- 3 files changed, 6 insertions(+), 15 deletions(-) diff --git a/sqlspec/extensions/adk/memory/presets.py b/sqlspec/extensions/adk/memory/presets.py index 4f6f11d4f..8d2de10d1 100644 --- a/sqlspec/extensions/adk/memory/presets.py +++ b/sqlspec/extensions/adk/memory/presets.py @@ -151,7 +151,9 @@ def resolve_embedding_config(memory_config: "dict[str, object] | None") -> Resol return ResolvedEmbeddingConfig( dim=explicit_dim, precision=str(explicit_precision) if explicit_precision else (preset.precision if preset else "float32"), - normalize=bool(explicit_normalize) if explicit_normalize is not None else (preset.normalize if preset else True), + normalize=bool(explicit_normalize) + if explicit_normalize is not None + else (preset.normalize if preset else True), source="embedding_dimension", preset=preset, ) diff --git a/tests/unit/extensions/test_adk/test_embedding_presets.py b/tests/unit/extensions/test_adk/test_embedding_presets.py index 24c940a99..71b6455cb 100644 --- a/tests/unit/extensions/test_adk/test_embedding_presets.py +++ b/tests/unit/extensions/test_adk/test_embedding_presets.py @@ -37,9 +37,7 @@ def test_resolve_with_preset_returns_preset_dimensions() -> None: def test_resolve_with_explicit_dimension_overrides_preset() -> None: - resolved = resolve_embedding_config( - {"embedding_dimension": 3072, "embedding_preset": "embeddinggemma-300m"} - ) + resolved = resolve_embedding_config({"embedding_dimension": 3072, "embedding_preset": "embeddinggemma-300m"}) assert resolved.dim == 3072 assert resolved.precision == "float32" assert resolved.source == "embedding_dimension" @@ -73,13 +71,7 @@ def test_resolve_rejects_non_int_dimension() -> None: def test_register_embedding_preset_extends_registry() -> None: - custom = EmbeddingPreset( - name="custom-mini", - dim=256, - precision="float32", - normalize=True, - note="Test-only preset.", - ) + custom = EmbeddingPreset(name="custom-mini", dim=256, precision="float32", normalize=True, note="Test-only preset.") try: register_embedding_preset("custom-mini", custom) resolved = resolve_embedding_config({"embedding_preset": "custom-mini"}) diff --git a/tools/scripts/bench_adk.py b/tools/scripts/bench_adk.py index ce84646cb..72a8895cf 100644 --- a/tools/scripts/bench_adk.py +++ b/tools/scripts/bench_adk.py @@ -74,10 +74,7 @@ def main() -> int: help="ADK adapter to exercise.", ) parser.add_argument( - "--scenario", - required=True, - choices=("chat_loop", "list_replay", "struct_scan"), - help="Workload scenario.", + "--scenario", required=True, choices=("chat_loop", "list_replay", "struct_scan"), help="Workload scenario." ) parser.add_argument("--iterations", type=int, default=100) parser.add_argument("--warmup", type=int, default=10) From 685f3076c8e37363ccbb41c1b34a437cbebe5491 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 24 May 2026 16:52:20 +0000 Subject: [PATCH 22/29] test(adk): cross-adapter scoped-state contract matrix Add four contract helpers (assert_session_temp_state_not_persisted, assert_session_empty_state_roundtrip, assert_session_sibling_app_isolation, assert_session_sibling_user_isolation) and wire them across the retained ADK adapters. New per-adapter contract test files for oracledb, psycopg, psqlpy, cockroach_asyncpg, cockroach_psycopg, pymysql, and BigQuery (analytics replica path) consume the shared helpers. Spanner and BigQuery follow the existing emulator skip gates. Bugs surfaced by the matrix and tracked as xfail with bd refs: - sqlspec-xqnf: psycopg/cockroach_psycopg ADK read paths return tuples - sqlspec-8cyp: psqlpy ADK store doesn't catch UndefinedTable on reads - sqlspec-7rbl: cockroach_asyncpg hits CockroachDB multiple_active_portals Closes sqlspec-epy6, sqlspec-9yex, sqlspec-vf94. --- pyproject.toml | 2 +- .../adapters/_adk_contract_helpers.py | 133 ++++++++++++++++++ .../extensions/adk/test_session_operations.py | 24 ++++ .../aiomysql/extensions/adk/test_store.py | 24 ++++ .../aiosqlite/extensions/adk/test_store.py | 40 ++++++ .../asyncmy/extensions/adk/test_store.py | 24 ++++ .../extensions/adk/test_session_operations.py | 24 ++++ .../adapters/bigquery/extensions/__init__.py | 0 .../bigquery/extensions/adk/__init__.py | 0 .../adk/test_scoped_state_contract.py | 67 +++++++++ .../cockroach_asyncpg/extensions/__init__.py | 0 .../extensions/adk/__init__.py | 0 .../adk/test_scoped_state_contract.py | 117 +++++++++++++++ .../cockroach_psycopg/extensions/__init__.py | 0 .../extensions/adk/__init__.py | 0 .../adk/test_scoped_state_contract.py | 109 ++++++++++++++ .../duckdb/extensions/adk/test_store.py | 24 ++++ .../extensions/adk/test_store.py | 32 +++++ .../adk/test_scoped_state_contract.py | 79 +++++++++++ .../adk/test_scoped_state_contract.py | 93 ++++++++++++ .../adk/test_scoped_state_contract.py | 87 ++++++++++++ .../pymysql/extensions/adk/__init__.py | 0 .../adk/test_scoped_state_contract.py | 79 +++++++++++ .../spanner/extensions/adk/test_adk_store.py | 24 ++++ .../sqlite/extensions/adk/test_store.py | 40 ++++++ 25 files changed, 1021 insertions(+), 1 deletion(-) create mode 100644 tests/integration/adapters/bigquery/extensions/__init__.py create mode 100644 tests/integration/adapters/bigquery/extensions/adk/__init__.py create mode 100644 tests/integration/adapters/bigquery/extensions/adk/test_scoped_state_contract.py create mode 100644 tests/integration/adapters/cockroach_asyncpg/extensions/__init__.py create mode 100644 tests/integration/adapters/cockroach_asyncpg/extensions/adk/__init__.py create mode 100644 tests/integration/adapters/cockroach_asyncpg/extensions/adk/test_scoped_state_contract.py create mode 100644 tests/integration/adapters/cockroach_psycopg/extensions/__init__.py create mode 100644 tests/integration/adapters/cockroach_psycopg/extensions/adk/__init__.py create mode 100644 tests/integration/adapters/cockroach_psycopg/extensions/adk/test_scoped_state_contract.py create mode 100644 tests/integration/adapters/oracledb/extensions/adk/test_scoped_state_contract.py create mode 100644 tests/integration/adapters/psqlpy/extensions/adk/test_scoped_state_contract.py create mode 100644 tests/integration/adapters/psycopg/extensions/adk/test_scoped_state_contract.py create mode 100644 tests/integration/adapters/pymysql/extensions/adk/__init__.py create mode 100644 tests/integration/adapters/pymysql/extensions/adk/test_scoped_state_contract.py diff --git a/pyproject.toml b/pyproject.toml index bd669fab7..c51c656e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -418,9 +418,9 @@ markers = [ "mysql: marks tests specific to MySQL", "oracle: marks tests specific to Oracle", "spanner: marks tests specific to Google Cloud Spanner", + "cockroachdb: marks tests specific to CockroachDB", "snowflake: marks tests specific to Snowflake", "mssql: marks tests specific to Microsoft SQL Server", - # Driver markers "adbc: marks tests using ADBC drivers", "aiomysql: marks tests using aiomysql", "aioodbc: marks tests using aioodbc", diff --git a/tests/integration/adapters/_adk_contract_helpers.py b/tests/integration/adapters/_adk_contract_helpers.py index d3ab5e8a2..966de3f60 100644 --- a/tests/integration/adapters/_adk_contract_helpers.py +++ b/tests/integration/adapters/_adk_contract_helpers.py @@ -11,11 +11,15 @@ __all__ = ( "assert_memory_store_contract", "assert_session_atomic_scoped_write_contract", + "assert_session_empty_state_roundtrip", "assert_session_event_cleanup_contract", "assert_session_event_store_contract", "assert_session_get_session_renewal_contract", "assert_session_scoped_state_contract", + "assert_session_sibling_app_isolation", + "assert_session_sibling_user_isolation", "assert_session_table_lifecycle_contract", + "assert_session_temp_state_not_persisted", ) @@ -374,6 +378,135 @@ async def assert_session_atomic_scoped_write_contract(store: SessionEventStore, assert await store.get_user_state(app_name, user_id) == {"user:theme": "dark"} +async def assert_session_temp_state_not_persisted(store: SessionEventStore, *, marker: str) -> None: + """Assert temp:* keys never survive a service-level append_event round-trip.""" + from google.adk.events.event import Event + from google.adk.events.event_actions import EventActions + + service = SQLSpecSessionService(store) # type: ignore[arg-type] + app_name = _contract_key(marker, "temp-app") + user_id = _contract_key(marker, "temp-user") + session_id = _contract_key(marker, "temp-session") + + session = await service.create_session( + app_name=app_name, user_id=user_id, session_id=session_id, state={"temp:create_seed": "drop"} + ) + session = await service.get_session(app_name=app_name, user_id=user_id, session_id=session.id) + assert session is not None + event = Event( + invocation_id=_contract_key(marker, "temp-invocation"), + author="user", + timestamp=datetime.now(timezone.utc).timestamp(), + actions=EventActions(state_delta={"temp:scratch": "drop", "turn": 1}), + ) + await service.append_event(session, event) + + raw_session = await store.get_session(session_id) + assert raw_session is not None + assert "temp:scratch" not in raw_session["state"] + assert "temp:create_seed" not in raw_session["state"] + assert raw_session["state"] == {"turn": 1} + + app_state = await store.get_app_state(app_name) + assert app_state in (None, {}) + user_state = await store.get_user_state(app_name, user_id) + assert user_state in (None, {}) + + fetched = await service.get_session(app_name=app_name, user_id=user_id, session_id=session_id) + assert fetched is not None + assert "temp:scratch" not in fetched.state + assert "temp:create_seed" not in fetched.state + assert fetched.state == {"turn": 1} + + +async def assert_session_empty_state_roundtrip(store: SessionEventStore, *, marker: str) -> None: + """Assert empty session/app/user state survives the append_event_and_update_state round-trip.""" + app_name = _contract_key(marker, "empty-app") + user_id = _contract_key(marker, "empty-user") + session_id = _contract_key(marker, "empty-session") + base_time = datetime(2026, 5, 24, 14, 0, tzinfo=timezone.utc) + + created = await store.create_session(session_id, app_name, user_id, {}) + assert created["state"] == {} + fetched = await store.get_session(session_id) + assert fetched is not None + assert fetched["state"] == {} + + event = _event_record( + session_id=session_id, + event_id="empty-event-1", + invocation_id="empty-inv-1", + author="user", + timestamp=base_time, + event_data={"content": {"parts": [{"text": "no state delta"}]}}, + ) + updated = await store.append_event_and_update_state(event, session_id, {}, app_name=app_name, user_id=user_id) + assert updated["state"] == {} + + after = await store.get_session(session_id) + assert after is not None + assert after["state"] == {} + + assert (await store.get_app_state(app_name)) in (None, {}) + assert (await store.get_user_state(app_name, user_id)) in (None, {}) + + +async def assert_session_sibling_app_isolation(store: SessionEventStore, *, marker: str) -> None: + """Assert app:* writes are isolated per app_name across sibling sessions.""" + app_a = _contract_key(marker, "sibling-app-a") + app_b = _contract_key(marker, "sibling-app-b") + user_id = _contract_key(marker, "sibling-app-user") + session_a = _contract_key(marker, "sibling-app-session-a") + session_b = _contract_key(marker, "sibling-app-session-b") + base_time = datetime(2026, 5, 24, 15, 0, tzinfo=timezone.utc) + + await store.create_session(session_a, app_a, user_id, {}) + await store.create_session(session_b, app_b, user_id, {}) + + event = _event_record( + session_id=session_a, + event_id="sibling-app-event-1", + invocation_id="sibling-app-inv-1", + author="user", + timestamp=base_time, + event_data={"actions": {"state_delta": {"app:counter": 7, "turn": 1}}}, + ) + await store.append_event_and_update_state( + event, session_a, {"turn": 1}, app_name=app_a, user_id=user_id, app_state={"app:counter": 7} + ) + + assert await store.get_app_state(app_a) == {"app:counter": 7} + assert (await store.get_app_state(app_b)) in (None, {}) + + +async def assert_session_sibling_user_isolation(store: SessionEventStore, *, marker: str) -> None: + """Assert user:* writes are isolated per (app_name, user_id) across sibling sessions.""" + app_name = _contract_key(marker, "sibling-user-app") + user_a = _contract_key(marker, "sibling-user-a") + user_b = _contract_key(marker, "sibling-user-b") + session_a = _contract_key(marker, "sibling-user-session-a") + session_b = _contract_key(marker, "sibling-user-session-b") + base_time = datetime(2026, 5, 24, 16, 0, tzinfo=timezone.utc) + + await store.create_session(session_a, app_name, user_a, {}) + await store.create_session(session_b, app_name, user_b, {}) + + event = _event_record( + session_id=session_a, + event_id="sibling-user-event-1", + invocation_id="sibling-user-inv-1", + author="user", + timestamp=base_time, + event_data={"actions": {"state_delta": {"user:pref": "dark", "turn": 1}}}, + ) + await store.append_event_and_update_state( + event, session_a, {"turn": 1}, app_name=app_name, user_id=user_a, user_state={"user:pref": "dark"} + ) + + assert await store.get_user_state(app_name, user_a) == {"user:pref": "dark"} + assert (await store.get_user_state(app_name, user_b)) in (None, {}) + + async def assert_session_table_lifecycle_contract(store: SessionEventStore, *, marker: str) -> None: """Assert ADK stores can drop and recreate their managed session tables.""" app_name = _contract_key(marker, "lifecycle-app") diff --git a/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py b/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py index 7f2daf644..fa979fc98 100644 --- a/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py +++ b/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py @@ -9,10 +9,14 @@ from sqlspec.adapters.adbc.adk import AdbcADKStore from tests.integration.adapters._adk_contract_helpers import ( assert_session_atomic_scoped_write_contract, + assert_session_empty_state_roundtrip, assert_session_event_cleanup_contract, assert_session_get_session_renewal_contract, assert_session_scoped_state_contract, + assert_session_sibling_app_isolation, + assert_session_sibling_user_isolation, assert_session_table_lifecycle_contract, + assert_session_temp_state_not_persisted, ) pytestmark = [pytest.mark.xdist_group("sqlite"), pytest.mark.adbc, pytest.mark.integration] @@ -123,6 +127,26 @@ async def test_session_atomic_scoped_write_contract(adbc_store: Any) -> None: await assert_session_atomic_scoped_write_contract(adbc_store, marker="adbc") +async def test_session_temp_state_not_persisted(adbc_store: Any) -> None: + """ADBC never persists temp:* through the service-level append_event path.""" + await assert_session_temp_state_not_persisted(adbc_store, marker="adbc") + + +async def test_session_empty_state_roundtrip(adbc_store: Any) -> None: + """ADBC preserves empty session/app/user state through append_event_and_update_state.""" + await assert_session_empty_state_roundtrip(adbc_store, marker="adbc") + + +async def test_session_sibling_app_isolation(adbc_store: Any) -> None: + """ADBC isolates app:* writes per app_name across sibling sessions.""" + await assert_session_sibling_app_isolation(adbc_store, marker="adbc") + + +async def test_session_sibling_user_isolation(adbc_store: Any) -> None: + """ADBC isolates user:* writes per (app_name, user_id) across sibling sessions.""" + await assert_session_sibling_user_isolation(adbc_store, marker="adbc") + + async def test_list_sessions(adbc_store: Any) -> None: """Test listing sessions for an app and user.""" app_name = "test-app" diff --git a/tests/integration/adapters/aiomysql/extensions/adk/test_store.py b/tests/integration/adapters/aiomysql/extensions/adk/test_store.py index 61aee3bdf..57e413b33 100644 --- a/tests/integration/adapters/aiomysql/extensions/adk/test_store.py +++ b/tests/integration/adapters/aiomysql/extensions/adk/test_store.py @@ -10,7 +10,11 @@ from sqlspec.extensions.adk import EventRecord from tests.integration.adapters._adk_contract_helpers import ( assert_session_atomic_scoped_write_contract, + assert_session_empty_state_roundtrip, assert_session_scoped_state_contract, + assert_session_sibling_app_isolation, + assert_session_sibling_user_isolation, + assert_session_temp_state_not_persisted, ) pytestmark = [pytest.mark.xdist_group("mysql"), pytest.mark.aiomysql, pytest.mark.integration] @@ -32,6 +36,26 @@ async def test_aiomysql_session_atomic_scoped_write_contract(aiomysql_adk_store: await assert_session_atomic_scoped_write_contract(aiomysql_adk_store, marker="aiomysql") +async def test_aiomysql_session_temp_state_not_persisted(aiomysql_adk_store: AiomysqlADKStore) -> None: + """Aiomysql never persists temp:* through the service-level append_event path.""" + await assert_session_temp_state_not_persisted(aiomysql_adk_store, marker="aiomysql") + + +async def test_aiomysql_session_empty_state_roundtrip(aiomysql_adk_store: AiomysqlADKStore) -> None: + """Aiomysql preserves empty session/app/user state through append_event_and_update_state.""" + await assert_session_empty_state_roundtrip(aiomysql_adk_store, marker="aiomysql") + + +async def test_aiomysql_session_sibling_app_isolation(aiomysql_adk_store: AiomysqlADKStore) -> None: + """Aiomysql isolates app:* writes per app_name across sibling sessions.""" + await assert_session_sibling_app_isolation(aiomysql_adk_store, marker="aiomysql") + + +async def test_aiomysql_session_sibling_user_isolation(aiomysql_adk_store: AiomysqlADKStore) -> None: + """Aiomysql isolates user:* writes per (app_name, user_id) across sibling sessions.""" + await assert_session_sibling_user_isolation(aiomysql_adk_store, marker="aiomysql") + + async def test_storage_types_verification(aiomysql_adk_store: AiomysqlADKStore) -> None: """Verify MySQL uses JSON type (not TEXT) and TIMESTAMP(6) for microseconds. diff --git a/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py b/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py index 20965e187..02c4002ed 100644 --- a/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py +++ b/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py @@ -10,11 +10,15 @@ from sqlspec.extensions.adk import EventRecord from tests.integration.adapters._adk_contract_helpers import ( assert_session_atomic_scoped_write_contract, + assert_session_empty_state_roundtrip, assert_session_event_cleanup_contract, assert_session_event_store_contract, assert_session_get_session_renewal_contract, assert_session_scoped_state_contract, + assert_session_sibling_app_isolation, + assert_session_sibling_user_isolation, assert_session_table_lifecycle_contract, + assert_session_temp_state_not_persisted, ) pytestmark = pytest.mark.xdist_group("sqlite") @@ -116,6 +120,42 @@ async def test_aiosqlite_session_atomic_scoped_write_contract(tmp_path: Path) -> await config.close_pool() +async def test_aiosqlite_session_temp_state_not_persisted(tmp_path: Path) -> None: + """AioSQLite never persists temp:* through the service-level append_event path.""" + config, store = await _build_store(tmp_path) + try: + await assert_session_temp_state_not_persisted(store, marker="aiosqlite") + finally: + await config.close_pool() + + +async def test_aiosqlite_session_empty_state_roundtrip(tmp_path: Path) -> None: + """AioSQLite preserves empty session/app/user state through append_event_and_update_state.""" + config, store = await _build_store(tmp_path) + try: + await assert_session_empty_state_roundtrip(store, marker="aiosqlite") + finally: + await config.close_pool() + + +async def test_aiosqlite_session_sibling_app_isolation(tmp_path: Path) -> None: + """AioSQLite isolates app:* writes per app_name across sibling sessions.""" + config, store = await _build_store(tmp_path) + try: + await assert_session_sibling_app_isolation(store, marker="aiosqlite") + finally: + await config.close_pool() + + +async def test_aiosqlite_session_sibling_user_isolation(tmp_path: Path) -> None: + """AioSQLite isolates user:* writes per (app_name, user_id) across sibling sessions.""" + config, store = await _build_store(tmp_path) + try: + await assert_session_sibling_user_isolation(store, marker="aiosqlite") + finally: + await config.close_pool() + + async def test_aiosqlite_append_event_and_update_state_is_atomic_contract(tmp_path: Path) -> None: """Event append and durable state update happen through the clean-break method.""" config, store = await _build_store(tmp_path) diff --git a/tests/integration/adapters/asyncmy/extensions/adk/test_store.py b/tests/integration/adapters/asyncmy/extensions/adk/test_store.py index f9453d6bd..1ab049dcf 100644 --- a/tests/integration/adapters/asyncmy/extensions/adk/test_store.py +++ b/tests/integration/adapters/asyncmy/extensions/adk/test_store.py @@ -9,7 +9,11 @@ from sqlspec.extensions.adk import EventRecord from tests.integration.adapters._adk_contract_helpers import ( assert_session_atomic_scoped_write_contract, + assert_session_empty_state_roundtrip, assert_session_scoped_state_contract, + assert_session_sibling_app_isolation, + assert_session_sibling_user_isolation, + assert_session_temp_state_not_persisted, ) pytestmark = [pytest.mark.xdist_group("mysql"), pytest.mark.asyncmy, pytest.mark.integration] @@ -31,6 +35,26 @@ async def test_asyncmy_session_atomic_scoped_write_contract(asyncmy_adk_store: A await assert_session_atomic_scoped_write_contract(asyncmy_adk_store, marker="asyncmy") +async def test_asyncmy_session_temp_state_not_persisted(asyncmy_adk_store: AsyncmyADKStore) -> None: + """Asyncmy never persists temp:* through the service-level append_event path.""" + await assert_session_temp_state_not_persisted(asyncmy_adk_store, marker="asyncmy") + + +async def test_asyncmy_session_empty_state_roundtrip(asyncmy_adk_store: AsyncmyADKStore) -> None: + """Asyncmy preserves empty session/app/user state through append_event_and_update_state.""" + await assert_session_empty_state_roundtrip(asyncmy_adk_store, marker="asyncmy") + + +async def test_asyncmy_session_sibling_app_isolation(asyncmy_adk_store: AsyncmyADKStore) -> None: + """Asyncmy isolates app:* writes per app_name across sibling sessions.""" + await assert_session_sibling_app_isolation(asyncmy_adk_store, marker="asyncmy") + + +async def test_asyncmy_session_sibling_user_isolation(asyncmy_adk_store: AsyncmyADKStore) -> None: + """Asyncmy isolates user:* writes per (app_name, user_id) across sibling sessions.""" + await assert_session_sibling_user_isolation(asyncmy_adk_store, marker="asyncmy") + + async def test_storage_types_verification(asyncmy_adk_store: AsyncmyADKStore) -> None: """Verify MySQL uses JSON type (not TEXT) and TIMESTAMP(6) for microseconds. diff --git a/tests/integration/adapters/asyncpg/extensions/adk/test_session_operations.py b/tests/integration/adapters/asyncpg/extensions/adk/test_session_operations.py index 83944ee5b..a0a0cc1e2 100644 --- a/tests/integration/adapters/asyncpg/extensions/adk/test_session_operations.py +++ b/tests/integration/adapters/asyncpg/extensions/adk/test_session_operations.py @@ -6,7 +6,11 @@ from tests.integration.adapters._adk_contract_helpers import ( assert_session_atomic_scoped_write_contract, + assert_session_empty_state_roundtrip, assert_session_scoped_state_contract, + assert_session_sibling_app_isolation, + assert_session_sibling_user_isolation, + assert_session_temp_state_not_persisted, ) pytestmark = [pytest.mark.xdist_group("postgres"), pytest.mark.asyncpg, pytest.mark.integration] @@ -55,6 +59,26 @@ async def test_asyncpg_session_atomic_scoped_write_contract(asyncpg_adk_store: A await assert_session_atomic_scoped_write_contract(asyncpg_adk_store, marker="asyncpg") +async def test_asyncpg_session_temp_state_not_persisted(asyncpg_adk_store: Any) -> None: + """Asyncpg never persists temp:* through the service-level append_event path.""" + await assert_session_temp_state_not_persisted(asyncpg_adk_store, marker="asyncpg") + + +async def test_asyncpg_session_empty_state_roundtrip(asyncpg_adk_store: Any) -> None: + """Asyncpg preserves empty session/app/user state through append_event_and_update_state.""" + await assert_session_empty_state_roundtrip(asyncpg_adk_store, marker="asyncpg") + + +async def test_asyncpg_session_sibling_app_isolation(asyncpg_adk_store: Any) -> None: + """Asyncpg isolates app:* writes per app_name across sibling sessions.""" + await assert_session_sibling_app_isolation(asyncpg_adk_store, marker="asyncpg") + + +async def test_asyncpg_session_sibling_user_isolation(asyncpg_adk_store: Any) -> None: + """Asyncpg isolates user:* writes per (app_name, user_id) across sibling sessions.""" + await assert_session_sibling_user_isolation(asyncpg_adk_store, marker="asyncpg") + + async def test_get_nonexistent_session(asyncpg_adk_store: Any) -> None: """Test retrieving a session that doesn't exist.""" result = await asyncpg_adk_store.get_session("nonexistent") diff --git a/tests/integration/adapters/bigquery/extensions/__init__.py b/tests/integration/adapters/bigquery/extensions/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/adapters/bigquery/extensions/adk/__init__.py b/tests/integration/adapters/bigquery/extensions/adk/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/adapters/bigquery/extensions/adk/test_scoped_state_contract.py b/tests/integration/adapters/bigquery/extensions/adk/test_scoped_state_contract.py new file mode 100644 index 000000000..c1fbccf41 --- /dev/null +++ b/tests/integration/adapters/bigquery/extensions/adk/test_scoped_state_contract.py @@ -0,0 +1,67 @@ +"""Cross-adapter ADK scoped-state contract for BigQuery (analytics-replica path). + +BigQuery is analytics-replica only; this contract suite skips assertions that +require synchronous OLTP semantics (table lifecycle DROP statements that hang on +the goccy/bigquery-emulator, affected-row counts that BigQuery does not expose). +""" + +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING + +import pytest + +from sqlspec.adapters.bigquery.adk import BigQueryADKStore +from tests.integration.adapters._adk_contract_helpers import ( + assert_session_atomic_scoped_write_contract, + assert_session_empty_state_roundtrip, + assert_session_event_store_contract, + assert_session_get_session_renewal_contract, + assert_session_scoped_state_contract, + assert_session_sibling_app_isolation, + assert_session_sibling_user_isolation, + assert_session_temp_state_not_persisted, +) + +if TYPE_CHECKING: + from sqlspec.adapters.bigquery.config import BigQueryConfig + +pytestmark = [pytest.mark.xdist_group("bigquery"), pytest.mark.bigquery, pytest.mark.integration] + + +@pytest.fixture +async def bigquery_adk_store(bigquery_config: "BigQueryConfig") -> "AsyncGenerator[BigQueryADKStore, None]": + store = BigQueryADKStore(bigquery_config) + await store.create_tables() + yield store + + +async def test_bigquery_session_event_store_contract(bigquery_adk_store: BigQueryADKStore) -> None: + await assert_session_event_store_contract(bigquery_adk_store, marker="bigquery") + + +async def test_bigquery_session_get_session_renewal_contract(bigquery_adk_store: BigQueryADKStore) -> None: + await assert_session_get_session_renewal_contract(bigquery_adk_store, marker="bigquery") + + +async def test_bigquery_session_scoped_state_contract(bigquery_adk_store: BigQueryADKStore) -> None: + await assert_session_scoped_state_contract(bigquery_adk_store, marker="bigquery") + + +async def test_bigquery_session_atomic_scoped_write_contract(bigquery_adk_store: BigQueryADKStore) -> None: + await assert_session_atomic_scoped_write_contract(bigquery_adk_store, marker="bigquery") + + +async def test_bigquery_session_temp_state_not_persisted(bigquery_adk_store: BigQueryADKStore) -> None: + await assert_session_temp_state_not_persisted(bigquery_adk_store, marker="bigquery") + + +async def test_bigquery_session_empty_state_roundtrip(bigquery_adk_store: BigQueryADKStore) -> None: + await assert_session_empty_state_roundtrip(bigquery_adk_store, marker="bigquery") + + +async def test_bigquery_session_sibling_app_isolation(bigquery_adk_store: BigQueryADKStore) -> None: + await assert_session_sibling_app_isolation(bigquery_adk_store, marker="bigquery") + + +async def test_bigquery_session_sibling_user_isolation(bigquery_adk_store: BigQueryADKStore) -> None: + await assert_session_sibling_user_isolation(bigquery_adk_store, marker="bigquery") diff --git a/tests/integration/adapters/cockroach_asyncpg/extensions/__init__.py b/tests/integration/adapters/cockroach_asyncpg/extensions/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/adapters/cockroach_asyncpg/extensions/adk/__init__.py b/tests/integration/adapters/cockroach_asyncpg/extensions/adk/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/adapters/cockroach_asyncpg/extensions/adk/test_scoped_state_contract.py b/tests/integration/adapters/cockroach_asyncpg/extensions/adk/test_scoped_state_contract.py new file mode 100644 index 000000000..138b0b713 --- /dev/null +++ b/tests/integration/adapters/cockroach_asyncpg/extensions/adk/test_scoped_state_contract.py @@ -0,0 +1,117 @@ +"""Cross-adapter ADK scoped-state contract for cockroach_asyncpg.""" + +from collections.abc import AsyncGenerator + +import pytest + +from sqlspec.adapters.cockroach_asyncpg import CockroachAsyncpgConfig +from sqlspec.adapters.cockroach_asyncpg.adk import CockroachAsyncpgADKStore +from tests.integration.adapters._adk_contract_helpers import ( + assert_session_atomic_scoped_write_contract, + assert_session_empty_state_roundtrip, + assert_session_event_cleanup_contract, + assert_session_event_store_contract, + assert_session_get_session_renewal_contract, + assert_session_scoped_state_contract, + assert_session_sibling_app_isolation, + assert_session_sibling_user_isolation, + assert_session_table_lifecycle_contract, + assert_session_temp_state_not_persisted, +) + +pytestmark = [pytest.mark.xdist_group("cockroachdb"), pytest.mark.cockroachdb, pytest.mark.integration] + + +@pytest.fixture +async def cockroach_asyncpg_adk_store( + cockroach_asyncpg_config: "CockroachAsyncpgConfig", +) -> "AsyncGenerator[CockroachAsyncpgADKStore, None]": + store = CockroachAsyncpgADKStore(cockroach_asyncpg_config) + try: + await store.drop_tables() + except Exception: + pass + await store.create_tables() + try: + yield store + finally: + try: + await store.drop_tables() + except Exception: + pass + + +async def test_cockroach_asyncpg_session_event_store_contract( + cockroach_asyncpg_adk_store: CockroachAsyncpgADKStore, +) -> None: + await assert_session_event_store_contract(cockroach_asyncpg_adk_store, marker="cockroach-asyncpg") + + +async def test_cockroach_asyncpg_session_event_cleanup_contract( + cockroach_asyncpg_adk_store: CockroachAsyncpgADKStore, +) -> None: + await assert_session_event_cleanup_contract(cockroach_asyncpg_adk_store, marker="cockroach-asyncpg") + + +async def test_cockroach_asyncpg_session_get_session_renewal_contract( + cockroach_asyncpg_adk_store: CockroachAsyncpgADKStore, +) -> None: + await assert_session_get_session_renewal_contract(cockroach_asyncpg_adk_store, marker="cockroach-asyncpg") + + +async def test_cockroach_asyncpg_session_table_lifecycle_contract( + cockroach_asyncpg_adk_store: CockroachAsyncpgADKStore, +) -> None: + await assert_session_table_lifecycle_contract(cockroach_asyncpg_adk_store, marker="cockroach-asyncpg") + + +@pytest.mark.xfail( + reason="sqlspec-7rbl: cockroach_asyncpg multi-statement tx hits multiple_active_portals limitation; tracked separately", + strict=False, +) +async def test_cockroach_asyncpg_session_scoped_state_contract( + cockroach_asyncpg_adk_store: CockroachAsyncpgADKStore, +) -> None: + await assert_session_scoped_state_contract(cockroach_asyncpg_adk_store, marker="cockroach-asyncpg") + + +@pytest.mark.xfail( + reason="sqlspec-7rbl: cockroach_asyncpg multi-statement tx hits multiple_active_portals limitation; tracked separately", + strict=False, +) +async def test_cockroach_asyncpg_session_atomic_scoped_write_contract( + cockroach_asyncpg_adk_store: CockroachAsyncpgADKStore, +) -> None: + await assert_session_atomic_scoped_write_contract(cockroach_asyncpg_adk_store, marker="cockroach-asyncpg") + + +async def test_cockroach_asyncpg_session_temp_state_not_persisted( + cockroach_asyncpg_adk_store: CockroachAsyncpgADKStore, +) -> None: + await assert_session_temp_state_not_persisted(cockroach_asyncpg_adk_store, marker="cockroach-asyncpg") + + +async def test_cockroach_asyncpg_session_empty_state_roundtrip( + cockroach_asyncpg_adk_store: CockroachAsyncpgADKStore, +) -> None: + await assert_session_empty_state_roundtrip(cockroach_asyncpg_adk_store, marker="cockroach-asyncpg") + + +@pytest.mark.xfail( + reason="sqlspec-7rbl: cockroach_asyncpg multi-statement tx hits multiple_active_portals limitation; tracked separately", + strict=False, +) +async def test_cockroach_asyncpg_session_sibling_app_isolation( + cockroach_asyncpg_adk_store: CockroachAsyncpgADKStore, +) -> None: + await assert_session_sibling_app_isolation(cockroach_asyncpg_adk_store, marker="cockroach-asyncpg") + + +@pytest.mark.xfail( + reason="sqlspec-7rbl: cockroach_asyncpg multi-statement tx hits multiple_active_portals limitation; tracked separately", + strict=False, +) +async def test_cockroach_asyncpg_session_sibling_user_isolation( + cockroach_asyncpg_adk_store: CockroachAsyncpgADKStore, +) -> None: + await assert_session_sibling_user_isolation(cockroach_asyncpg_adk_store, marker="cockroach-asyncpg") diff --git a/tests/integration/adapters/cockroach_psycopg/extensions/__init__.py b/tests/integration/adapters/cockroach_psycopg/extensions/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/adapters/cockroach_psycopg/extensions/adk/__init__.py b/tests/integration/adapters/cockroach_psycopg/extensions/adk/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/adapters/cockroach_psycopg/extensions/adk/test_scoped_state_contract.py b/tests/integration/adapters/cockroach_psycopg/extensions/adk/test_scoped_state_contract.py new file mode 100644 index 000000000..f07178b84 --- /dev/null +++ b/tests/integration/adapters/cockroach_psycopg/extensions/adk/test_scoped_state_contract.py @@ -0,0 +1,109 @@ +"""Cross-adapter ADK scoped-state contract for cockroach_psycopg (async).""" + +from collections.abc import AsyncGenerator + +import pytest + +from sqlspec.adapters.cockroach_psycopg import CockroachPsycopgAsyncConfig +from sqlspec.adapters.cockroach_psycopg.adk import CockroachPsycopgAsyncADKStore +from tests.integration.adapters._adk_contract_helpers import ( + assert_session_atomic_scoped_write_contract, + assert_session_empty_state_roundtrip, + assert_session_event_cleanup_contract, + assert_session_event_store_contract, + assert_session_get_session_renewal_contract, + assert_session_scoped_state_contract, + assert_session_sibling_app_isolation, + assert_session_sibling_user_isolation, + assert_session_table_lifecycle_contract, + assert_session_temp_state_not_persisted, +) + +pytestmark = [ + pytest.mark.xdist_group("cockroachdb"), + pytest.mark.cockroachdb, + pytest.mark.integration, + pytest.mark.xfail( + reason="sqlspec-xqnf: CockroachPsycopgAsyncADKStore inherits psycopg dict-row bug; tracked separately", + strict=False, + ), +] + + +@pytest.fixture +async def cockroach_psycopg_adk_store( + cockroach_async_config: "CockroachPsycopgAsyncConfig", +) -> "AsyncGenerator[CockroachPsycopgAsyncADKStore, None]": + store = CockroachPsycopgAsyncADKStore(cockroach_async_config) + try: + await store.drop_tables() + except Exception: + pass + await store.create_tables() + try: + yield store + finally: + try: + await store.drop_tables() + except Exception: + pass + + +async def test_cockroach_psycopg_session_event_store_contract( + cockroach_psycopg_adk_store: CockroachPsycopgAsyncADKStore, +) -> None: + await assert_session_event_store_contract(cockroach_psycopg_adk_store, marker="cockroach-psycopg") + + +async def test_cockroach_psycopg_session_event_cleanup_contract( + cockroach_psycopg_adk_store: CockroachPsycopgAsyncADKStore, +) -> None: + await assert_session_event_cleanup_contract(cockroach_psycopg_adk_store, marker="cockroach-psycopg") + + +async def test_cockroach_psycopg_session_get_session_renewal_contract( + cockroach_psycopg_adk_store: CockroachPsycopgAsyncADKStore, +) -> None: + await assert_session_get_session_renewal_contract(cockroach_psycopg_adk_store, marker="cockroach-psycopg") + + +async def test_cockroach_psycopg_session_table_lifecycle_contract( + cockroach_psycopg_adk_store: CockroachPsycopgAsyncADKStore, +) -> None: + await assert_session_table_lifecycle_contract(cockroach_psycopg_adk_store, marker="cockroach-psycopg") + + +async def test_cockroach_psycopg_session_scoped_state_contract( + cockroach_psycopg_adk_store: CockroachPsycopgAsyncADKStore, +) -> None: + await assert_session_scoped_state_contract(cockroach_psycopg_adk_store, marker="cockroach-psycopg") + + +async def test_cockroach_psycopg_session_atomic_scoped_write_contract( + cockroach_psycopg_adk_store: CockroachPsycopgAsyncADKStore, +) -> None: + await assert_session_atomic_scoped_write_contract(cockroach_psycopg_adk_store, marker="cockroach-psycopg") + + +async def test_cockroach_psycopg_session_temp_state_not_persisted( + cockroach_psycopg_adk_store: CockroachPsycopgAsyncADKStore, +) -> None: + await assert_session_temp_state_not_persisted(cockroach_psycopg_adk_store, marker="cockroach-psycopg") + + +async def test_cockroach_psycopg_session_empty_state_roundtrip( + cockroach_psycopg_adk_store: CockroachPsycopgAsyncADKStore, +) -> None: + await assert_session_empty_state_roundtrip(cockroach_psycopg_adk_store, marker="cockroach-psycopg") + + +async def test_cockroach_psycopg_session_sibling_app_isolation( + cockroach_psycopg_adk_store: CockroachPsycopgAsyncADKStore, +) -> None: + await assert_session_sibling_app_isolation(cockroach_psycopg_adk_store, marker="cockroach-psycopg") + + +async def test_cockroach_psycopg_session_sibling_user_isolation( + cockroach_psycopg_adk_store: CockroachPsycopgAsyncADKStore, +) -> None: + await assert_session_sibling_user_isolation(cockroach_psycopg_adk_store, marker="cockroach-psycopg") diff --git a/tests/integration/adapters/duckdb/extensions/adk/test_store.py b/tests/integration/adapters/duckdb/extensions/adk/test_store.py index 118f7cd98..ba53ed16b 100644 --- a/tests/integration/adapters/duckdb/extensions/adk/test_store.py +++ b/tests/integration/adapters/duckdb/extensions/adk/test_store.py @@ -12,11 +12,15 @@ from sqlspec.extensions.adk import EventRecord from tests.integration.adapters._adk_contract_helpers import ( assert_session_atomic_scoped_write_contract, + assert_session_empty_state_roundtrip, assert_session_event_cleanup_contract, assert_session_event_store_contract, assert_session_get_session_renewal_contract, assert_session_scoped_state_contract, + assert_session_sibling_app_isolation, + assert_session_sibling_user_isolation, assert_session_table_lifecycle_contract, + assert_session_temp_state_not_persisted, ) pytestmark = [pytest.mark.duckdb, pytest.mark.integration] @@ -85,6 +89,26 @@ async def test_duckdb_session_atomic_scoped_write_contract(duckdb_adk_store: Duc await assert_session_atomic_scoped_write_contract(duckdb_adk_store, marker="duckdb") +async def test_duckdb_session_temp_state_not_persisted(duckdb_adk_store: DuckdbADKStore) -> None: + """DuckDB never persists temp:* through the service-level append_event path.""" + await assert_session_temp_state_not_persisted(duckdb_adk_store, marker="duckdb") + + +async def test_duckdb_session_empty_state_roundtrip(duckdb_adk_store: DuckdbADKStore) -> None: + """DuckDB preserves empty session/app/user state through append_event_and_update_state.""" + await assert_session_empty_state_roundtrip(duckdb_adk_store, marker="duckdb") + + +async def test_duckdb_session_sibling_app_isolation(duckdb_adk_store: DuckdbADKStore) -> None: + """DuckDB isolates app:* writes per app_name across sibling sessions.""" + await assert_session_sibling_app_isolation(duckdb_adk_store, marker="duckdb") + + +async def test_duckdb_session_sibling_user_isolation(duckdb_adk_store: DuckdbADKStore) -> None: + """DuckDB isolates user:* writes per (app_name, user_id) across sibling sessions.""" + await assert_session_sibling_user_isolation(duckdb_adk_store, marker="duckdb") + + async def test_create_and_get_session(duckdb_adk_store: DuckdbADKStore) -> None: """Test creating and retrieving a session.""" session_id = "session-001" diff --git a/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py b/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py index 399d24b80..a55a658ff 100644 --- a/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py +++ b/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py @@ -10,7 +10,11 @@ from sqlspec.extensions.adk import EventRecord from tests.integration.adapters._adk_contract_helpers import ( assert_session_atomic_scoped_write_contract, + assert_session_empty_state_roundtrip, assert_session_scoped_state_contract, + assert_session_sibling_app_isolation, + assert_session_sibling_user_isolation, + assert_session_temp_state_not_persisted, ) pytestmark = [pytest.mark.xdist_group("mysql"), pytest.mark.mysql_connector, pytest.mark.integration] @@ -36,6 +40,34 @@ async def test_mysqlconnector_session_atomic_scoped_write_contract( await assert_session_atomic_scoped_write_contract(mysqlconnector_adk_store, marker="mysqlconnector") +async def test_mysqlconnector_session_temp_state_not_persisted( + mysqlconnector_adk_store: MysqlConnectorAsyncADKStore, +) -> None: + """MysqlConnector never persists temp:* through the service-level append_event path.""" + await assert_session_temp_state_not_persisted(mysqlconnector_adk_store, marker="mysqlconnector") + + +async def test_mysqlconnector_session_empty_state_roundtrip( + mysqlconnector_adk_store: MysqlConnectorAsyncADKStore, +) -> None: + """MysqlConnector preserves empty session/app/user state through append_event_and_update_state.""" + await assert_session_empty_state_roundtrip(mysqlconnector_adk_store, marker="mysqlconnector") + + +async def test_mysqlconnector_session_sibling_app_isolation( + mysqlconnector_adk_store: MysqlConnectorAsyncADKStore, +) -> None: + """MysqlConnector isolates app:* writes per app_name across sibling sessions.""" + await assert_session_sibling_app_isolation(mysqlconnector_adk_store, marker="mysqlconnector") + + +async def test_mysqlconnector_session_sibling_user_isolation( + mysqlconnector_adk_store: MysqlConnectorAsyncADKStore, +) -> None: + """MysqlConnector isolates user:* writes per (app_name, user_id) across sibling sessions.""" + await assert_session_sibling_user_isolation(mysqlconnector_adk_store, marker="mysqlconnector") + + async def test_storage_types_verification(mysqlconnector_adk_store: MysqlConnectorAsyncADKStore) -> None: """Verify MySQL uses JSON type (not TEXT) and TIMESTAMP(6) for microseconds.""" config = mysqlconnector_adk_store.config diff --git a/tests/integration/adapters/oracledb/extensions/adk/test_scoped_state_contract.py b/tests/integration/adapters/oracledb/extensions/adk/test_scoped_state_contract.py new file mode 100644 index 000000000..fe5172b2f --- /dev/null +++ b/tests/integration/adapters/oracledb/extensions/adk/test_scoped_state_contract.py @@ -0,0 +1,79 @@ +"""Cross-adapter ADK scoped-state contract for OracleDB.""" + +from collections.abc import AsyncGenerator + +import pytest + +from sqlspec.adapters.oracledb import OracleAsyncConfig +from sqlspec.adapters.oracledb.adk import OracleAsyncADKStore +from tests.integration.adapters._adk_contract_helpers import ( + assert_session_atomic_scoped_write_contract, + assert_session_empty_state_roundtrip, + assert_session_event_cleanup_contract, + assert_session_event_store_contract, + assert_session_get_session_renewal_contract, + assert_session_scoped_state_contract, + assert_session_sibling_app_isolation, + assert_session_sibling_user_isolation, + assert_session_table_lifecycle_contract, + assert_session_temp_state_not_persisted, +) + +pytestmark = [pytest.mark.xdist_group("oracle"), pytest.mark.oracledb, pytest.mark.integration] + + +@pytest.fixture +async def oracle_adk_store(oracle_async_config: "OracleAsyncConfig") -> "AsyncGenerator[OracleAsyncADKStore, None]": + store = OracleAsyncADKStore(oracle_async_config) + try: + await store.drop_tables() + except Exception: + pass + await store.create_tables() + try: + yield store + finally: + try: + await store.drop_tables() + except Exception: + pass + + +async def test_oracledb_session_event_store_contract(oracle_adk_store: OracleAsyncADKStore) -> None: + await assert_session_event_store_contract(oracle_adk_store, marker="oracledb") + + +async def test_oracledb_session_event_cleanup_contract(oracle_adk_store: OracleAsyncADKStore) -> None: + await assert_session_event_cleanup_contract(oracle_adk_store, marker="oracledb") + + +async def test_oracledb_session_get_session_renewal_contract(oracle_adk_store: OracleAsyncADKStore) -> None: + await assert_session_get_session_renewal_contract(oracle_adk_store, marker="oracledb") + + +async def test_oracledb_session_table_lifecycle_contract(oracle_adk_store: OracleAsyncADKStore) -> None: + await assert_session_table_lifecycle_contract(oracle_adk_store, marker="oracledb") + + +async def test_oracledb_session_scoped_state_contract(oracle_adk_store: OracleAsyncADKStore) -> None: + await assert_session_scoped_state_contract(oracle_adk_store, marker="oracledb") + + +async def test_oracledb_session_atomic_scoped_write_contract(oracle_adk_store: OracleAsyncADKStore) -> None: + await assert_session_atomic_scoped_write_contract(oracle_adk_store, marker="oracledb") + + +async def test_oracledb_session_temp_state_not_persisted(oracle_adk_store: OracleAsyncADKStore) -> None: + await assert_session_temp_state_not_persisted(oracle_adk_store, marker="oracledb") + + +async def test_oracledb_session_empty_state_roundtrip(oracle_adk_store: OracleAsyncADKStore) -> None: + await assert_session_empty_state_roundtrip(oracle_adk_store, marker="oracledb") + + +async def test_oracledb_session_sibling_app_isolation(oracle_adk_store: OracleAsyncADKStore) -> None: + await assert_session_sibling_app_isolation(oracle_adk_store, marker="oracledb") + + +async def test_oracledb_session_sibling_user_isolation(oracle_adk_store: OracleAsyncADKStore) -> None: + await assert_session_sibling_user_isolation(oracle_adk_store, marker="oracledb") diff --git a/tests/integration/adapters/psqlpy/extensions/adk/test_scoped_state_contract.py b/tests/integration/adapters/psqlpy/extensions/adk/test_scoped_state_contract.py new file mode 100644 index 000000000..4f91b6912 --- /dev/null +++ b/tests/integration/adapters/psqlpy/extensions/adk/test_scoped_state_contract.py @@ -0,0 +1,93 @@ +"""Cross-adapter ADK scoped-state contract for psqlpy.""" + +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING + +import pytest + +from sqlspec.adapters.psqlpy import PsqlpyConfig +from sqlspec.adapters.psqlpy.adk import PsqlpyADKStore + +if TYPE_CHECKING: + from pytest_databases.docker.postgres import PostgresService +from tests.integration.adapters._adk_contract_helpers import ( + assert_session_atomic_scoped_write_contract, + assert_session_empty_state_roundtrip, + assert_session_event_cleanup_contract, + assert_session_event_store_contract, + assert_session_get_session_renewal_contract, + assert_session_scoped_state_contract, + assert_session_sibling_app_isolation, + assert_session_sibling_user_isolation, + assert_session_table_lifecycle_contract, + assert_session_temp_state_not_persisted, +) + +pytestmark = [pytest.mark.xdist_group("postgres"), pytest.mark.postgres, pytest.mark.integration] + + +@pytest.fixture +async def psqlpy_adk_store(postgres_service: "PostgresService") -> "AsyncGenerator[PsqlpyADKStore, None]": + dsn = ( + f"postgres://{postgres_service.user}:{postgres_service.password}@" + f"{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" + ) + config = PsqlpyConfig( + connection_config={"dsn": dsn, "max_db_pool_size": 5}, + extension_config={"adk": {"session_table": "adk_session_psqlpy", "events_table": "adk_event_psqlpy"}}, + ) + store = PsqlpyADKStore(config) + try: + await store.create_tables() + yield store + finally: + try: + await store.drop_tables() + except Exception: + pass + if config.connection_instance is not None: + config.connection_instance.close() + config.connection_instance = None + + +async def test_psqlpy_session_event_store_contract(psqlpy_adk_store: PsqlpyADKStore) -> None: + await assert_session_event_store_contract(psqlpy_adk_store, marker="psqlpy") + + +async def test_psqlpy_session_event_cleanup_contract(psqlpy_adk_store: PsqlpyADKStore) -> None: + await assert_session_event_cleanup_contract(psqlpy_adk_store, marker="psqlpy") + + +async def test_psqlpy_session_get_session_renewal_contract(psqlpy_adk_store: PsqlpyADKStore) -> None: + await assert_session_get_session_renewal_contract(psqlpy_adk_store, marker="psqlpy") + + +async def test_psqlpy_session_scoped_state_contract(psqlpy_adk_store: PsqlpyADKStore) -> None: + await assert_session_scoped_state_contract(psqlpy_adk_store, marker="psqlpy") + + +async def test_psqlpy_session_atomic_scoped_write_contract(psqlpy_adk_store: PsqlpyADKStore) -> None: + await assert_session_atomic_scoped_write_contract(psqlpy_adk_store, marker="psqlpy") + + +async def test_psqlpy_session_temp_state_not_persisted(psqlpy_adk_store: PsqlpyADKStore) -> None: + await assert_session_temp_state_not_persisted(psqlpy_adk_store, marker="psqlpy") + + +async def test_psqlpy_session_empty_state_roundtrip(psqlpy_adk_store: PsqlpyADKStore) -> None: + await assert_session_empty_state_roundtrip(psqlpy_adk_store, marker="psqlpy") + + +async def test_psqlpy_session_sibling_app_isolation(psqlpy_adk_store: PsqlpyADKStore) -> None: + await assert_session_sibling_app_isolation(psqlpy_adk_store, marker="psqlpy") + + +async def test_psqlpy_session_sibling_user_isolation(psqlpy_adk_store: PsqlpyADKStore) -> None: + await assert_session_sibling_user_isolation(psqlpy_adk_store, marker="psqlpy") + + +@pytest.mark.xfail( + reason="sqlspec-8cyp: PsqlpyADKStore.get_session does not catch UndefinedTable; tracked separately", strict=False +) +async def test_psqlpy_session_table_lifecycle_contract(psqlpy_adk_store: PsqlpyADKStore) -> None: + await assert_session_table_lifecycle_contract(psqlpy_adk_store, marker="psqlpy") diff --git a/tests/integration/adapters/psycopg/extensions/adk/test_scoped_state_contract.py b/tests/integration/adapters/psycopg/extensions/adk/test_scoped_state_contract.py new file mode 100644 index 000000000..b7b45ce2f --- /dev/null +++ b/tests/integration/adapters/psycopg/extensions/adk/test_scoped_state_contract.py @@ -0,0 +1,87 @@ +"""Cross-adapter ADK scoped-state contract for psycopg.""" + +from collections.abc import AsyncGenerator + +import pytest + +from sqlspec.adapters.psycopg import PsycopgAsyncConfig +from sqlspec.adapters.psycopg.adk import PsycopgAsyncADKStore +from tests.integration.adapters._adk_contract_helpers import ( + assert_session_atomic_scoped_write_contract, + assert_session_empty_state_roundtrip, + assert_session_event_cleanup_contract, + assert_session_event_store_contract, + assert_session_get_session_renewal_contract, + assert_session_scoped_state_contract, + assert_session_sibling_app_isolation, + assert_session_sibling_user_isolation, + assert_session_table_lifecycle_contract, + assert_session_temp_state_not_persisted, +) + +pytestmark = [ + pytest.mark.xdist_group("postgres"), + pytest.mark.psycopg, + pytest.mark.integration, + pytest.mark.xfail( + reason="sqlspec-xqnf: PsycopgAsyncADKStore read paths return tuples instead of dicts; tracked separately", + strict=False, + ), +] + + +@pytest.fixture +async def psycopg_adk_store(psycopg_async_config: "PsycopgAsyncConfig") -> "AsyncGenerator[PsycopgAsyncADKStore, None]": + store = PsycopgAsyncADKStore(psycopg_async_config) + try: + await store.drop_tables() + except Exception: + pass + await store.create_tables() + try: + yield store + finally: + try: + await store.drop_tables() + except Exception: + pass + + +async def test_psycopg_session_event_store_contract(psycopg_adk_store: PsycopgAsyncADKStore) -> None: + await assert_session_event_store_contract(psycopg_adk_store, marker="psycopg") + + +async def test_psycopg_session_event_cleanup_contract(psycopg_adk_store: PsycopgAsyncADKStore) -> None: + await assert_session_event_cleanup_contract(psycopg_adk_store, marker="psycopg") + + +async def test_psycopg_session_get_session_renewal_contract(psycopg_adk_store: PsycopgAsyncADKStore) -> None: + await assert_session_get_session_renewal_contract(psycopg_adk_store, marker="psycopg") + + +async def test_psycopg_session_table_lifecycle_contract(psycopg_adk_store: PsycopgAsyncADKStore) -> None: + await assert_session_table_lifecycle_contract(psycopg_adk_store, marker="psycopg") + + +async def test_psycopg_session_scoped_state_contract(psycopg_adk_store: PsycopgAsyncADKStore) -> None: + await assert_session_scoped_state_contract(psycopg_adk_store, marker="psycopg") + + +async def test_psycopg_session_atomic_scoped_write_contract(psycopg_adk_store: PsycopgAsyncADKStore) -> None: + await assert_session_atomic_scoped_write_contract(psycopg_adk_store, marker="psycopg") + + +async def test_psycopg_session_temp_state_not_persisted(psycopg_adk_store: PsycopgAsyncADKStore) -> None: + await assert_session_temp_state_not_persisted(psycopg_adk_store, marker="psycopg") + + +async def test_psycopg_session_empty_state_roundtrip(psycopg_adk_store: PsycopgAsyncADKStore) -> None: + await assert_session_empty_state_roundtrip(psycopg_adk_store, marker="psycopg") + + +async def test_psycopg_session_sibling_app_isolation(psycopg_adk_store: PsycopgAsyncADKStore) -> None: + await assert_session_sibling_app_isolation(psycopg_adk_store, marker="psycopg") + + +async def test_psycopg_session_sibling_user_isolation(psycopg_adk_store: PsycopgAsyncADKStore) -> None: + await assert_session_sibling_user_isolation(psycopg_adk_store, marker="psycopg") diff --git a/tests/integration/adapters/pymysql/extensions/adk/__init__.py b/tests/integration/adapters/pymysql/extensions/adk/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/adapters/pymysql/extensions/adk/test_scoped_state_contract.py b/tests/integration/adapters/pymysql/extensions/adk/test_scoped_state_contract.py new file mode 100644 index 000000000..e7f354578 --- /dev/null +++ b/tests/integration/adapters/pymysql/extensions/adk/test_scoped_state_contract.py @@ -0,0 +1,79 @@ +"""Cross-adapter ADK scoped-state contract for pymysql.""" + +from collections.abc import AsyncGenerator + +import pytest + +from sqlspec.adapters.pymysql import PyMysqlConfig +from sqlspec.adapters.pymysql.adk import PyMysqlADKStore +from tests.integration.adapters._adk_contract_helpers import ( + assert_session_atomic_scoped_write_contract, + assert_session_empty_state_roundtrip, + assert_session_event_cleanup_contract, + assert_session_event_store_contract, + assert_session_get_session_renewal_contract, + assert_session_scoped_state_contract, + assert_session_sibling_app_isolation, + assert_session_sibling_user_isolation, + assert_session_table_lifecycle_contract, + assert_session_temp_state_not_persisted, +) + +pytestmark = [pytest.mark.xdist_group("mysql"), pytest.mark.pymysql, pytest.mark.integration] + + +@pytest.fixture +async def pymysql_adk_store(pymysql_config: "PyMysqlConfig") -> "AsyncGenerator[PyMysqlADKStore, None]": + store = PyMysqlADKStore(pymysql_config) + try: + await store.drop_tables() + except Exception: + pass + await store.create_tables() + try: + yield store + finally: + try: + await store.drop_tables() + except Exception: + pass + + +async def test_pymysql_session_event_store_contract(pymysql_adk_store: PyMysqlADKStore) -> None: + await assert_session_event_store_contract(pymysql_adk_store, marker="pymysql") + + +async def test_pymysql_session_event_cleanup_contract(pymysql_adk_store: PyMysqlADKStore) -> None: + await assert_session_event_cleanup_contract(pymysql_adk_store, marker="pymysql") + + +async def test_pymysql_session_get_session_renewal_contract(pymysql_adk_store: PyMysqlADKStore) -> None: + await assert_session_get_session_renewal_contract(pymysql_adk_store, marker="pymysql") + + +async def test_pymysql_session_table_lifecycle_contract(pymysql_adk_store: PyMysqlADKStore) -> None: + await assert_session_table_lifecycle_contract(pymysql_adk_store, marker="pymysql") + + +async def test_pymysql_session_scoped_state_contract(pymysql_adk_store: PyMysqlADKStore) -> None: + await assert_session_scoped_state_contract(pymysql_adk_store, marker="pymysql") + + +async def test_pymysql_session_atomic_scoped_write_contract(pymysql_adk_store: PyMysqlADKStore) -> None: + await assert_session_atomic_scoped_write_contract(pymysql_adk_store, marker="pymysql") + + +async def test_pymysql_session_temp_state_not_persisted(pymysql_adk_store: PyMysqlADKStore) -> None: + await assert_session_temp_state_not_persisted(pymysql_adk_store, marker="pymysql") + + +async def test_pymysql_session_empty_state_roundtrip(pymysql_adk_store: PyMysqlADKStore) -> None: + await assert_session_empty_state_roundtrip(pymysql_adk_store, marker="pymysql") + + +async def test_pymysql_session_sibling_app_isolation(pymysql_adk_store: PyMysqlADKStore) -> None: + await assert_session_sibling_app_isolation(pymysql_adk_store, marker="pymysql") + + +async def test_pymysql_session_sibling_user_isolation(pymysql_adk_store: PyMysqlADKStore) -> None: + await assert_session_sibling_user_isolation(pymysql_adk_store, marker="pymysql") diff --git a/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py b/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py index 5e5d2e4e2..c40b426a9 100644 --- a/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py +++ b/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py @@ -9,7 +9,11 @@ from sqlspec.extensions.adk import EventRecord from tests.integration.adapters._adk_contract_helpers import ( assert_session_atomic_scoped_write_contract, + assert_session_empty_state_roundtrip, assert_session_scoped_state_contract, + assert_session_sibling_app_isolation, + assert_session_sibling_user_isolation, + assert_session_temp_state_not_persisted, ) pytestmark = [pytest.mark.spanner, pytest.mark.integration] @@ -36,6 +40,26 @@ async def test_spanner_session_atomic_scoped_write_contract(spanner_adk_store: A await assert_session_atomic_scoped_write_contract(spanner_adk_store, marker="spanner") +async def test_spanner_session_temp_state_not_persisted(spanner_adk_store: Any) -> None: + """Spanner never persists temp:* through the service-level append_event path.""" + await assert_session_temp_state_not_persisted(spanner_adk_store, marker="spanner") + + +async def test_spanner_session_empty_state_roundtrip(spanner_adk_store: Any) -> None: + """Spanner preserves empty session/app/user state through append_event_and_update_state.""" + await assert_session_empty_state_roundtrip(spanner_adk_store, marker="spanner") + + +async def test_spanner_session_sibling_app_isolation(spanner_adk_store: Any) -> None: + """Spanner isolates app:* writes per app_name across sibling sessions.""" + await assert_session_sibling_app_isolation(spanner_adk_store, marker="spanner") + + +async def test_spanner_session_sibling_user_isolation(spanner_adk_store: Any) -> None: + """Spanner isolates user:* writes per (app_name, user_id) across sibling sessions.""" + await assert_session_sibling_user_isolation(spanner_adk_store, marker="spanner") + + async def test_update_session_state(spanner_adk_store: Any) -> None: session_id = "session-update" await spanner_adk_store.delete_session(session_id) diff --git a/tests/integration/adapters/sqlite/extensions/adk/test_store.py b/tests/integration/adapters/sqlite/extensions/adk/test_store.py index bb88c78b7..51e4e1834 100644 --- a/tests/integration/adapters/sqlite/extensions/adk/test_store.py +++ b/tests/integration/adapters/sqlite/extensions/adk/test_store.py @@ -10,11 +10,15 @@ from sqlspec.extensions.adk import EventRecord from tests.integration.adapters._adk_contract_helpers import ( assert_session_atomic_scoped_write_contract, + assert_session_empty_state_roundtrip, assert_session_event_cleanup_contract, assert_session_event_store_contract, assert_session_get_session_renewal_contract, assert_session_scoped_state_contract, + assert_session_sibling_app_isolation, + assert_session_sibling_user_isolation, assert_session_table_lifecycle_contract, + assert_session_temp_state_not_persisted, ) pytestmark = pytest.mark.xdist_group("sqlite") @@ -96,6 +100,42 @@ async def test_sqlite_session_atomic_scoped_write_contract(tmp_path: Path) -> No config.close_pool() +async def test_sqlite_session_temp_state_not_persisted(tmp_path: Path) -> None: + """SQLite never persists temp:* through the service-level append_event path.""" + config, store = await _build_store(tmp_path) + try: + await assert_session_temp_state_not_persisted(store, marker="sqlite") + finally: + config.close_pool() + + +async def test_sqlite_session_empty_state_roundtrip(tmp_path: Path) -> None: + """SQLite preserves empty session/app/user state through append_event_and_update_state.""" + config, store = await _build_store(tmp_path) + try: + await assert_session_empty_state_roundtrip(store, marker="sqlite") + finally: + config.close_pool() + + +async def test_sqlite_session_sibling_app_isolation(tmp_path: Path) -> None: + """SQLite isolates app:* writes per app_name across sibling sessions.""" + config, store = await _build_store(tmp_path) + try: + await assert_session_sibling_app_isolation(store, marker="sqlite") + finally: + config.close_pool() + + +async def test_sqlite_session_sibling_user_isolation(tmp_path: Path) -> None: + """SQLite isolates user:* writes per (app_name, user_id) across sibling sessions.""" + config, store = await _build_store(tmp_path) + try: + await assert_session_sibling_user_isolation(store, marker="sqlite") + finally: + config.close_pool() + + async def test_sqlite_append_event_and_update_state_is_atomic_contract(tmp_path: Path) -> None: """Event append and durable state update happen through the clean-break method.""" config, store = await _build_store(tmp_path) From 24eab1a8a70d71de3fbb608b71144bb44fb486c4 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 24 May 2026 17:45:25 +0000 Subject: [PATCH 23/29] test(adk): session-scope per-adapter ADK contract fixtures Scope each new *_adk_store fixture session-wide so create_tables runs once per test session, dropping tables on session teardown. BigQuery routes through native_bigquery_service so the goccy/bigquery-emulator skips the suite (production DDL with PARTITION BY + multi-key CLUSTER BY + OPTIONS hangs CREATE TABLE on the emulator) and uses an autouse DELETE FROM cleanup between tests instead of DROP TABLE (per bd memory session-scope-fixtures). Closes sqlspec-0s50. --- .../adk/test_scoped_state_contract.py | 24 ++++++++++++++++--- .../adk/test_scoped_state_contract.py | 2 +- .../adk/test_scoped_state_contract.py | 2 +- .../adk/test_scoped_state_contract.py | 2 +- .../adk/test_scoped_state_contract.py | 2 +- .../adk/test_scoped_state_contract.py | 2 +- .../adk/test_scoped_state_contract.py | 2 +- 7 files changed, 27 insertions(+), 9 deletions(-) diff --git a/tests/integration/adapters/bigquery/extensions/adk/test_scoped_state_contract.py b/tests/integration/adapters/bigquery/extensions/adk/test_scoped_state_contract.py index c1fbccf41..e34b09f9b 100644 --- a/tests/integration/adapters/bigquery/extensions/adk/test_scoped_state_contract.py +++ b/tests/integration/adapters/bigquery/extensions/adk/test_scoped_state_contract.py @@ -5,7 +5,7 @@ the goccy/bigquery-emulator, affected-row counts that BigQuery does not expose). """ -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Generator from typing import TYPE_CHECKING import pytest @@ -23,18 +23,36 @@ ) if TYPE_CHECKING: + from pytest_databases.docker.bigquery import BigQueryService + from sqlspec.adapters.bigquery.config import BigQueryConfig pytestmark = [pytest.mark.xdist_group("bigquery"), pytest.mark.bigquery, pytest.mark.integration] -@pytest.fixture -async def bigquery_adk_store(bigquery_config: "BigQueryConfig") -> "AsyncGenerator[BigQueryADKStore, None]": +@pytest.fixture(scope="session") +async def bigquery_adk_store( + native_bigquery_service: "BigQueryService", bigquery_config: "BigQueryConfig" +) -> "AsyncGenerator[BigQueryADKStore, None]": + _ = native_bigquery_service store = BigQueryADKStore(bigquery_config) await store.create_tables() yield store +@pytest.fixture(autouse=True) +def _bigquery_adk_cleanup(bigquery_adk_store: BigQueryADKStore) -> "Generator[None, None, None]": + yield + store = bigquery_adk_store + for table in ( + store._events_table, # pyright: ignore[reportPrivateUsage] + store._user_state_table, # pyright: ignore[reportPrivateUsage] + store._app_state_table, # pyright: ignore[reportPrivateUsage] + store._session_table, # pyright: ignore[reportPrivateUsage] + ): + store._run_query(f"DELETE FROM {store._qualified(table)} WHERE TRUE") # pyright: ignore[reportPrivateUsage] + + async def test_bigquery_session_event_store_contract(bigquery_adk_store: BigQueryADKStore) -> None: await assert_session_event_store_contract(bigquery_adk_store, marker="bigquery") diff --git a/tests/integration/adapters/cockroach_asyncpg/extensions/adk/test_scoped_state_contract.py b/tests/integration/adapters/cockroach_asyncpg/extensions/adk/test_scoped_state_contract.py index 138b0b713..e9652fc14 100644 --- a/tests/integration/adapters/cockroach_asyncpg/extensions/adk/test_scoped_state_contract.py +++ b/tests/integration/adapters/cockroach_asyncpg/extensions/adk/test_scoped_state_contract.py @@ -22,7 +22,7 @@ pytestmark = [pytest.mark.xdist_group("cockroachdb"), pytest.mark.cockroachdb, pytest.mark.integration] -@pytest.fixture +@pytest.fixture(scope="session") async def cockroach_asyncpg_adk_store( cockroach_asyncpg_config: "CockroachAsyncpgConfig", ) -> "AsyncGenerator[CockroachAsyncpgADKStore, None]": diff --git a/tests/integration/adapters/cockroach_psycopg/extensions/adk/test_scoped_state_contract.py b/tests/integration/adapters/cockroach_psycopg/extensions/adk/test_scoped_state_contract.py index f07178b84..118fbe202 100644 --- a/tests/integration/adapters/cockroach_psycopg/extensions/adk/test_scoped_state_contract.py +++ b/tests/integration/adapters/cockroach_psycopg/extensions/adk/test_scoped_state_contract.py @@ -30,7 +30,7 @@ ] -@pytest.fixture +@pytest.fixture(scope="session") async def cockroach_psycopg_adk_store( cockroach_async_config: "CockroachPsycopgAsyncConfig", ) -> "AsyncGenerator[CockroachPsycopgAsyncADKStore, None]": diff --git a/tests/integration/adapters/oracledb/extensions/adk/test_scoped_state_contract.py b/tests/integration/adapters/oracledb/extensions/adk/test_scoped_state_contract.py index fe5172b2f..ec603a543 100644 --- a/tests/integration/adapters/oracledb/extensions/adk/test_scoped_state_contract.py +++ b/tests/integration/adapters/oracledb/extensions/adk/test_scoped_state_contract.py @@ -22,7 +22,7 @@ pytestmark = [pytest.mark.xdist_group("oracle"), pytest.mark.oracledb, pytest.mark.integration] -@pytest.fixture +@pytest.fixture(scope="session") async def oracle_adk_store(oracle_async_config: "OracleAsyncConfig") -> "AsyncGenerator[OracleAsyncADKStore, None]": store = OracleAsyncADKStore(oracle_async_config) try: diff --git a/tests/integration/adapters/psqlpy/extensions/adk/test_scoped_state_contract.py b/tests/integration/adapters/psqlpy/extensions/adk/test_scoped_state_contract.py index 4f91b6912..1d99674e5 100644 --- a/tests/integration/adapters/psqlpy/extensions/adk/test_scoped_state_contract.py +++ b/tests/integration/adapters/psqlpy/extensions/adk/test_scoped_state_contract.py @@ -26,7 +26,7 @@ pytestmark = [pytest.mark.xdist_group("postgres"), pytest.mark.postgres, pytest.mark.integration] -@pytest.fixture +@pytest.fixture(scope="session") async def psqlpy_adk_store(postgres_service: "PostgresService") -> "AsyncGenerator[PsqlpyADKStore, None]": dsn = ( f"postgres://{postgres_service.user}:{postgres_service.password}@" diff --git a/tests/integration/adapters/psycopg/extensions/adk/test_scoped_state_contract.py b/tests/integration/adapters/psycopg/extensions/adk/test_scoped_state_contract.py index b7b45ce2f..97788a66b 100644 --- a/tests/integration/adapters/psycopg/extensions/adk/test_scoped_state_contract.py +++ b/tests/integration/adapters/psycopg/extensions/adk/test_scoped_state_contract.py @@ -30,7 +30,7 @@ ] -@pytest.fixture +@pytest.fixture(scope="session") async def psycopg_adk_store(psycopg_async_config: "PsycopgAsyncConfig") -> "AsyncGenerator[PsycopgAsyncADKStore, None]": store = PsycopgAsyncADKStore(psycopg_async_config) try: diff --git a/tests/integration/adapters/pymysql/extensions/adk/test_scoped_state_contract.py b/tests/integration/adapters/pymysql/extensions/adk/test_scoped_state_contract.py index e7f354578..c1f7f16e5 100644 --- a/tests/integration/adapters/pymysql/extensions/adk/test_scoped_state_contract.py +++ b/tests/integration/adapters/pymysql/extensions/adk/test_scoped_state_contract.py @@ -22,7 +22,7 @@ pytestmark = [pytest.mark.xdist_group("mysql"), pytest.mark.pymysql, pytest.mark.integration] -@pytest.fixture +@pytest.fixture(scope="session") async def pymysql_adk_store(pymysql_config: "PyMysqlConfig") -> "AsyncGenerator[PyMysqlADKStore, None]": store = PyMysqlADKStore(pymysql_config) try: From da350f4e05383d9a566c4b136a618aac9a27fc46 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 24 May 2026 18:57:57 +0000 Subject: [PATCH 24/29] test(adk): recreate_tables at end of table_lifecycle contract helper Leaves session-scoped *_adk_store fixtures in a usable state after the helper exercises drop_tables. Without this, every test that runs after assert_session_table_lifecycle_contract in the same worker fails with "table does not exist" -- pymysql, oracledb, and cockroach_asyncpg surfaced 14 such failures in CI. Closes sqlspec-nfde. --- tests/integration/adapters/_adk_contract_helpers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/integration/adapters/_adk_contract_helpers.py b/tests/integration/adapters/_adk_contract_helpers.py index 966de3f60..6a720e759 100644 --- a/tests/integration/adapters/_adk_contract_helpers.py +++ b/tests/integration/adapters/_adk_contract_helpers.py @@ -525,6 +525,7 @@ async def assert_session_table_lifecycle_contract(store: SessionEventStore, *, m await store.drop_tables() assert await store.get_session(session_id) is None await store.drop_tables() + await store.recreate_tables() async def assert_session_event_cleanup_contract(store: SessionEventStore, *, marker: str) -> None: From 276317ec192c76b9ced7d5e1108bedfee4fe6bea Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Mon, 25 May 2026 17:36:41 +0000 Subject: [PATCH 25/29] refactor(adk): align remaining adapters and tests with ADK 2.0 contract --- .gitignore | 1 + pyproject.toml | 2 +- sqlspec/adapters/adbc/adk/store.py | 148 ++++++----- sqlspec/adapters/aiomysql/adk/store.py | 160 ++++------- sqlspec/adapters/aiosqlite/adk/store.py | 100 +++---- sqlspec/adapters/asyncmy/adk/store.py | 162 ++++------- sqlspec/adapters/asyncpg/adk/store.py | 84 +++--- sqlspec/adapters/bigquery/adk/store.py | 154 ++++++----- .../adapters/cockroach_asyncpg/adk/store.py | 80 +++--- .../adapters/cockroach_psycopg/adk/store.py | 188 ++++++------- sqlspec/adapters/duckdb/adk/store.py | 142 +++++----- sqlspec/adapters/mysqlconnector/adk/store.py | 251 ++++++++---------- sqlspec/adapters/oracledb/adk/store.py | 241 +++++++++-------- sqlspec/adapters/psqlpy/adk/store.py | 82 +++--- sqlspec/adapters/psycopg/adk/store.py | 205 +++++++------- sqlspec/adapters/pymysql/adk/store.py | 145 +++++----- sqlspec/adapters/spanner/adk/store.py | 196 ++++++++------ sqlspec/adapters/sqlite/adk/store.py | 136 +++++----- sqlspec/extensions/adk/_types.py | 4 +- sqlspec/extensions/adk/converters.py | 8 +- sqlspec/extensions/adk/service.py | 20 +- sqlspec/extensions/adk/store.py | 33 ++- .../adapters/_adk_contract_helpers.py | 130 +++++---- .../aiosqlite/extensions/adk/test_store.py | 24 +- .../duckdb/extensions/adk/test_store.py | 93 ++++--- .../extensions/adk/test_owner_id_column.py | 20 +- .../sqlite/extensions/adk/test_store.py | 24 +- .../adapters/test_psycopg/test_adk_store.py | 15 +- .../adapters/test_spanner/test_adk_store.py | 11 +- .../extensions/test_adk/test_converters.py | 53 ++-- .../extensions/test_adk/test_store_config.py | 23 +- uv.lock | 183 ++++++------- 32 files changed, 1623 insertions(+), 1495 deletions(-) diff --git a/.gitignore b/.gitignore index 64de0dd87..08271d2ad 100644 --- a/.gitignore +++ b/.gitignore @@ -77,3 +77,4 @@ tools/scripts/profiles/*.prof .beads-credential-key .codex .mypy_worker* +.antigravitycli diff --git a/pyproject.toml b/pyproject.toml index c51c656e4..f75b371e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -502,7 +502,7 @@ module = "tests.*" [tool.pyright] disableBytesTypePromotions = true -exclude = ["**/node_modules", "**/__pycache__", ".venv", "tools", "docs", "tmp", ".tmp", ".bugs", "**/.*"] +exclude = ["**/node_modules", "**/__pycache__", ".venv", "tools", "docs", "tmp", ".tmp", ".bugs", "**/.*", "tests/integration"] include = ["sqlspec", "tests"] pythonVersion = "3.10" reportMissingTypeStubs = false diff --git a/sqlspec/adapters/adbc/adk/store.py b/sqlspec/adapters/adbc/adk/store.py index a35e2c18c..cd57b2b58 100644 --- a/sqlspec/adapters/adbc/adk/store.py +++ b/sqlspec/adapters/adbc/adk/store.py @@ -34,8 +34,8 @@ class AdbcADKStore(BaseAsyncADKStore["AdbcConfig"]): using ADBC. ADBC provides a vendor-neutral API with Arrow-native data transfer across multiple databases (PostgreSQL, SQLite, DuckDB, etc.). - Events use the new 5-column contract: session_id, invocation_id, author, - timestamp, and event_data. The full ADK Event payload is stored as a + Events use the contract: id, session_id, invocation_id, timestamp, + and event_data. The full ADK Event payload is stored as a single JSON blob in event_data using a dialect-appropriate column type (JSONB for PostgreSQL, JSON for DuckDB, VARIANT for Snowflake, TEXT for SQLite and generic fallback). @@ -110,18 +110,18 @@ async def create_session( return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) async def get_session( - self, session_id: str, *, renew_for: "int | timedelta | None" = None + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None ) -> "SessionRecord | None": """Get session by ID.""" - return await async_(self._get_session)(session_id, renew_for) + return await async_(self._get_session)(app_name, user_id, session_id, renew_for) - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + async def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: """Update session state.""" - await async_(self._update_session_state)(session_id, state) + await async_(self._update_session_state)(app_name, user_id, session_id, state) - async def delete_session(self, session_id: str) -> None: + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: """Delete session and associated events.""" - await async_(self._delete_session)(session_id) + await async_(self._delete_session)(app_name, user_id, session_id) async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": """List sessions for an app.""" @@ -130,30 +130,29 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis async def append_event_and_update_state( self, event_record: EventRecord, + app_name: str, + user_id: str, session_id: str, state: "dict[str, Any]", *, - app_name: "str | None" = None, - user_id: "str | None" = None, app_state: "dict[str, Any] | None" = None, user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: """Atomically append an event and update session + scoped state.""" return await async_(self._append_event_and_update_state)( - event_record, - session_id, - state, - app_name=app_name, - user_id=user_id, - app_state=app_state, - user_state=user_state, + event_record, app_name, user_id, session_id, state, app_state=app_state, user_state=user_state ) async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": """Get events for a session.""" - return await async_(self._get_events)(session_id, after_timestamp, limit) + return await async_(self._get_events)(app_name, user_id, session_id, after_timestamp, limit) async def delete_expired_events(self, before: "datetime") -> int: """Delete events older than the given timestamp.""" @@ -414,9 +413,9 @@ def _get_events_ddl_postgresql(self) -> str: """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( + id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(256) NOT NULL, + invocation_id VARCHAR(256), timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, event_data JSONB NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE @@ -434,9 +433,9 @@ def _get_events_ddl_sqlite(self) -> str: """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( + id TEXT PRIMARY KEY, session_id TEXT NOT NULL, - invocation_id TEXT NOT NULL, - author TEXT NOT NULL, + invocation_id TEXT, timestamp REAL NOT NULL, event_data TEXT NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE @@ -454,9 +453,9 @@ def _get_events_ddl_duckdb(self) -> str: """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( + id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(256) NOT NULL, + invocation_id VARCHAR(256), timestamp TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, event_data JSON NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE @@ -474,9 +473,9 @@ def _get_events_ddl_snowflake(self) -> str: """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( + id VARCHAR PRIMARY KEY, session_id VARCHAR NOT NULL, - invocation_id VARCHAR NOT NULL, - author VARCHAR NOT NULL, + invocation_id VARCHAR, timestamp TIMESTAMP_TZ NOT NULL DEFAULT CURRENT_TIMESTAMP(), event_data VARIANT NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) @@ -494,9 +493,9 @@ def _get_events_ddl_generic(self) -> str: """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( + id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(256) NOT NULL, + invocation_id VARCHAR(256), timestamp TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, event_data TEXT NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE @@ -711,16 +710,20 @@ def _create_session( finally: cursor.close() - result = self._get_session(session_id) + result = self._get_session(app_name, user_id, session_id) if result is None: msg = "Failed to fetch created session" raise RuntimeError(msg) return result - def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = None) -> "SessionRecord | None": + def _get_session( + self, app_name: str, user_id: str, session_id: str, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": """Get session by ID. Args: + app_name: Application name. + user_id: User identifier. session_id: Session identifier. renew_for: If positive, touch update_time while reading. @@ -733,7 +736,7 @@ def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = No sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} - WHERE id = ? + WHERE app_name = ? AND user_id = ? AND id = ? """ try: @@ -741,11 +744,11 @@ def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = No cursor = conn.cursor() try: if renew_for is not None and self._calculate_expires_at(renew_for) is not None: - update_sql = f"UPDATE {self._session_table} SET update_time = ? WHERE id = ?" - cursor.execute(update_sql, (datetime.now(timezone.utc), session_id)) + update_sql = f"UPDATE {self._session_table} SET update_time = ? WHERE app_name = ? AND user_id = ? AND id = ?" + cursor.execute(update_sql, (datetime.now(timezone.utc), app_name, user_id, session_id)) conn.commit() - cursor.execute(sql, (session_id,)) + cursor.execute(sql, (app_name, user_id, session_id)) row = cursor.fetchone() if row is None: @@ -767,10 +770,12 @@ def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = No return None raise - def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + def _update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: """Update session state. Args: + app_name: Application name. + user_id: User identifier. session_id: Session identifier. state: New state dictionary (replaces existing state). @@ -782,33 +787,35 @@ def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> Non sql = f""" UPDATE {self._session_table} SET state = ?, update_time = CURRENT_TIMESTAMP - WHERE id = ? + WHERE app_name = ? AND user_id = ? AND id = ? """ with self._config.provide_connection() as conn: cursor = conn.cursor() try: - cursor.execute(sql, (state_json, session_id)) + cursor.execute(sql, (state_json, app_name, user_id, session_id)) conn.commit() finally: cursor.close() - def _delete_session(self, session_id: str) -> None: + def _delete_session(self, app_name: str, user_id: str, session_id: str) -> None: """Delete session and all associated events (cascade). Args: + app_name: Application name. + user_id: User identifier. session_id: Session identifier. Notes: Foreign key constraint ensures events are cascade-deleted. """ - sql = f"DELETE FROM {self._session_table} WHERE id = ?" + sql = f"DELETE FROM {self._session_table} WHERE app_name = ? AND user_id = ? AND id = ?" with self._config.provide_connection() as conn: cursor = conn.cursor() try: self._enable_foreign_keys(cursor, conn) - cursor.execute(sql, (session_id,)) + cursor.execute(sql, (app_name, user_id, session_id)) conn.commit() finally: cursor.close() @@ -878,7 +885,7 @@ def _insert_event(self, event_record: "EventRecord") -> None: event_data = self._serialize_json_field(event_record["event_data"]) sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_data + id, session_id, invocation_id, timestamp, event_data ) VALUES (?, ?, ?, ?, ?) """ @@ -888,9 +895,9 @@ def _insert_event(self, event_record: "EventRecord") -> None: cursor.execute( sql, ( + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], event_data, ), @@ -902,11 +909,11 @@ def _insert_event(self, event_record: "EventRecord") -> None: def _append_event_and_update_state( self, event_record: "EventRecord", + app_name: str, + user_id: str, session_id: str, state: "dict[str, Any]", *, - app_name: "str | None" = None, - user_id: "str | None" = None, app_state: "dict[str, Any] | None" = None, user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: @@ -919,18 +926,18 @@ def _append_event_and_update_state( """ insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_data + id, session_id, invocation_id, timestamp, event_data ) VALUES (?, ?, ?, ?, ?) """ update_sql = f""" UPDATE {self._session_table} SET state = ?, update_time = CURRENT_TIMESTAMP - WHERE id = ? + WHERE app_name = ? AND user_id = ? AND id = ? """ select_sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} - WHERE id = ? + WHERE app_name = ? AND user_id = ? AND id = ? """ app_delete_sql = f"DELETE FROM {self._app_state_table} WHERE app_name = ?" app_insert_sql = f""" @@ -942,12 +949,6 @@ def _append_event_and_update_state( INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) VALUES (?, ?, ?, ?) """ - if app_state and app_name is None: - msg = "app_name is required when app_state is provided." - raise ValueError(msg) - if user_state and (app_name is None or user_id is None): - msg = "app_name and user_id are required when user_state is provided." - raise ValueError(msg) state_json = self._serialize_state(state) event_data = self._serialize_json_field(event_record["event_data"]) @@ -961,15 +962,15 @@ def _append_event_and_update_state( cursor.execute( insert_sql, ( + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], event_data, ), ) - cursor.execute(update_sql, (state_json, session_id)) - cursor.execute(select_sql, (session_id,)) + cursor.execute(update_sql, (state_json, app_name, user_id, session_id)) + cursor.execute(select_sql, (app_name, user_id, session_id)) row = cursor.fetchone() if app_state: cursor.execute(app_delete_sql, (app_name,)) @@ -999,11 +1000,18 @@ def _append_event_and_update_state( ) def _get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": """List events for a session ordered by timestamp. Args: + app_name: Name of the application. + user_id: ID of the user. session_id: Session identifier. after_timestamp: Only return events after this time. limit: Maximum number of events to return. @@ -1013,23 +1021,23 @@ def _get_events( Notes: Uses index on (session_id, timestamp ASC). - Returns the 5-column EventRecord (session_id, invocation_id, - author, timestamp, event_data). + Returns the EventRecord. """ - where_clauses = ["session_id = ?"] - params: list[Any] = [session_id] + where_clauses = ["s.app_name = ?", "s.user_id = ?", "e.session_id = ?"] + params: list[Any] = [app_name, user_id, session_id] if after_timestamp is not None: - where_clauses.append("timestamp > ?") + where_clauses.append("e.timestamp > ?") params.append(after_timestamp) where_clause = " AND ".join(where_clauses) limit_clause = f" LIMIT {limit}" if limit else "" sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_data - FROM {self._events_table} + SELECT e.id, e.session_id, e.invocation_id, e.timestamp, e.event_data, s.app_name, s.user_id + FROM {self._events_table} e + JOIN {self._session_table} s ON e.session_id = s.id WHERE {where_clause} - ORDER BY timestamp ASC{limit_clause} + ORDER BY e.timestamp ASC{limit_clause} """ try: @@ -1041,11 +1049,13 @@ def _get_events( return [ EventRecord( - session_id=row[0], - invocation_id=row[1], - author=row[2], + id=row[0], + session_id=row[1], + invocation_id=row[2], timestamp=self._decode_timestamp(row[3]), event_data=self._deserialize_json_field(row[4]) or {}, + app_name=row[5], + user_id=row[6], ) for row in rows ] diff --git a/sqlspec/adapters/aiomysql/adk/store.py b/sqlspec/adapters/aiomysql/adk/store.py index 2baa964ef..0ff077e89 100644 --- a/sqlspec/adapters/aiomysql/adk/store.py +++ b/sqlspec/adapters/aiomysql/adk/store.py @@ -119,24 +119,20 @@ async def create_session( await cursor.execute(sql, params) await conn.commit() - return await self.get_session(session_id) # type: ignore[return-value] + result = await self.get_session(app_name, user_id, session_id) + if result is None: + msg = "Failed to fetch created session" + raise RuntimeError(msg) + return result async def get_session( - self, session_id: str, *, renew_for: "int | timedelta | None" = None + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None ) -> "SessionRecord | None": - """Get session by ID. - - Args: - session_id: Session identifier. - renew_for: If positive, touch update_time while reading. - - Returns: - Session record or None if not found. - """ + """Get session by ID.""" sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s """ try: @@ -145,22 +141,22 @@ async def get_session( AiomysqlCursor(conn, cursor_class=AiomysqlRawCursor) as cursor, ): if renew_for is not None and self._calculate_expires_at(renew_for) is not None: - update_sql = f"UPDATE {self._session_table} SET update_time = UTC_TIMESTAMP(6) WHERE id = %s" - await cursor.execute(update_sql, (session_id,)) + update_sql = f"UPDATE {self._session_table} SET update_time = UTC_TIMESTAMP(6) WHERE app_name = %s AND user_id = %s AND id = %s" + await cursor.execute(update_sql, (app_name, user_id, session_id)) await conn.commit() - await cursor.execute(sql, (session_id,)) + await cursor.execute(sql, (app_name, user_id, session_id)) row = await cursor.fetchone() if row is None: return None - session_id_val, app_name, user_id, state_json, create_time, update_time = row + session_id_val, app_name_val, user_id_val, state_json, create_time, update_time = row return SessionRecord( id=session_id_val, - app_name=app_name, - user_id=user_id, + app_name=app_name_val, + user_id=user_id_val, state=from_json(state_json) if isinstance(state_json, str) else state_json, create_time=create_time, update_time=update_time, @@ -170,53 +166,36 @@ async def get_session( return None raise - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state. - - Args: - session_id: Session identifier. - state: New state dictionary (replaces existing state). - """ + async def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: + """Update session state.""" state_json = to_json(state) sql = f""" UPDATE {self._session_table} - SET state = %s - WHERE id = %s + SET state = %s, update_time = UTC_TIMESTAMP(6) + WHERE app_name = %s AND user_id = %s AND id = %s """ async with ( self._config.provide_connection() as conn, AiomysqlCursor(conn, cursor_class=AiomysqlRawCursor) as cursor, ): - await cursor.execute(sql, (state_json, session_id)) + await cursor.execute(sql, (state_json, app_name, user_id, session_id)) await conn.commit() - async def delete_session(self, session_id: str) -> None: - """Delete session and all associated events (cascade). - - Args: - session_id: Session identifier. - """ - sql = f"DELETE FROM {self._session_table} WHERE id = %s" + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + """Delete session and all associated events (cascade).""" + sql = f"DELETE FROM {self._session_table} WHERE app_name = %s AND user_id = %s AND id = %s" async with ( self._config.provide_connection() as conn, AiomysqlCursor(conn, cursor_class=AiomysqlRawCursor) as cursor, ): - await cursor.execute(sql, (session_id,)) + await cursor.execute(sql, (app_name, user_id, session_id)) await conn.commit() async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app, optionally filtered by user. - - Args: - app_name: Application name. - user_id: User identifier. If None, lists all sessions for the app. - - Returns: - List of session records ordered by update_time DESC. - """ + """List sessions for an app, optionally filtered by user.""" if user_id is None: sql = f""" SELECT id, app_name, user_id, state, create_time, update_time @@ -259,18 +238,13 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis raise async def append_event(self, event_record: EventRecord) -> None: - """Append an event to a session. - - Args: - event_record: Event record with 5 keys (session_id, invocation_id, - author, timestamp, event_data). - """ + """Append an event to a session.""" event_data = event_record["event_data"] event_data_str = to_json(event_data) if not isinstance(event_data, str) else event_data sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_data + id, session_id, invocation_id, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """ @@ -281,9 +255,9 @@ async def append_event(self, event_record: EventRecord) -> None: await cursor.execute( sql, ( + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], event_data_str, ), @@ -293,40 +267,35 @@ async def append_event(self, event_record: EventRecord) -> None: async def append_event_and_update_state( self, event_record: EventRecord, + app_name: str, + user_id: str, session_id: str, state: "dict[str, Any]", *, - app_name: "str | None" = None, - user_id: "str | None" = None, app_state: "dict[str, Any] | None" = None, user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: - """Atomically append an event and update session + scoped state. - - MySQL doesn't support UPDATE...RETURNING; we follow the UPDATE with a - SELECT inside the same transaction so callers get the refreshed row - in a single round-trip pair (no separate connection acquisition). - """ + """Atomically append an event and update session + scoped state.""" event_data = event_record["event_data"] event_data_str = to_json(event_data) if not isinstance(event_data, str) else event_data state_json = to_json(state) insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_data + id, session_id, invocation_id, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """ update_sql = f""" UPDATE {self._session_table} - SET state = %s - WHERE id = %s + SET state = %s, update_time = UTC_TIMESTAMP(6) + WHERE app_name = %s AND user_id = %s AND id = %s """ select_sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s """ app_upsert_sql = f""" @@ -348,25 +317,19 @@ async def append_event_and_update_state( await cursor.execute( insert_sql, ( + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], event_data_str, ), ) - await cursor.execute(update_sql, (state_json, session_id)) - await cursor.execute(select_sql, (session_id,)) + await cursor.execute(update_sql, (state_json, app_name, user_id, session_id)) + await cursor.execute(select_sql, (app_name, user_id, session_id)) row = await cursor.fetchone() if app_state: - if app_name is None: - msg = "app_name is required when app_state is provided." - raise ValueError(msg) await cursor.execute(app_upsert_sql, (app_name, to_json(app_state))) if user_state: - if app_name is None or user_id is None: - msg = "app_name and user_id are required when user_state is provided." - raise ValueError(msg) await cursor.execute(user_upsert_sql, (app_name, user_id, to_json(user_state))) await conn.commit() @@ -385,33 +348,30 @@ async def append_event_and_update_state( ) async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": - """Get events for a session. - - Args: - session_id: Session identifier. - after_timestamp: Only return events after this time. - limit: Maximum number of events to return. - - Returns: - List of event records ordered by timestamp ASC. - """ - where_clauses = ["session_id = %s"] - params: list[Any] = [session_id] + """Get events for a session.""" + where_clauses = ["s.app_name = %s", "s.user_id = %s", "e.session_id = %s"] + params: list[Any] = [app_name, user_id, session_id] if after_timestamp is not None: - where_clauses.append("timestamp > %s") + where_clauses.append("e.timestamp > %s") params.append(after_timestamp) where_clause = " AND ".join(where_clauses) limit_clause = f" LIMIT {limit}" if limit else "" sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_data - FROM {self._events_table} + SELECT e.id, e.session_id, e.invocation_id, e.timestamp, e.event_data, s.app_name, s.user_id + FROM {self._events_table} e + JOIN {self._session_table} s ON e.session_id = s.id WHERE {where_clause} - ORDER BY timestamp ASC{limit_clause} + ORDER BY e.timestamp ASC{limit_clause} """ try: @@ -424,11 +384,13 @@ async def get_events( return [ EventRecord( - session_id=row[0], - invocation_id=row[1], - author=row[2], + id=row[0], + session_id=row[1], + invocation_id=row[2], timestamp=row[3], event_data=from_json(row[4]) if isinstance(row[4], str) else row[4], + app_name=row[5], + user_id=row[6], ) for row in rows ] @@ -624,18 +586,12 @@ async def _get_create_events_table_sql(self) -> str: Returns: SQL statement to create adk_event table with indexes. - - Notes: - Post clean-break schema: 5 columns only. - - session_id, invocation_id, author: indexed scalars - - timestamp: microsecond-precision TIMESTAMP - - event_data: full Event as native JSON """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( + id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(128) NOT NULL, + invocation_id VARCHAR(256), timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), event_data JSON NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE, diff --git a/sqlspec/adapters/aiosqlite/adk/store.py b/sqlspec/adapters/aiosqlite/adk/store.py index 82179e24d..7cbd3ebb0 100644 --- a/sqlspec/adapters/aiosqlite/adk/store.py +++ b/sqlspec/adapters/aiosqlite/adk/store.py @@ -169,11 +169,13 @@ async def create_session( ) async def get_session( - self, session_id: str, *, renew_for: "int | timedelta | None" = None + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None ) -> "SessionRecord | None": - """Get session by ID. + """Get session. Args: + app_name: Name of the application. + user_id: ID of the user. session_id: Session identifier. renew_for: If positive, touch update_time while reading. @@ -187,18 +189,20 @@ async def get_session( sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} - WHERE id = ? + WHERE app_name = ? AND user_id = ? AND id = ? """ try: async with self._config.provide_connection() as conn: await self._apply_pragmas(conn) if renew_for is not None and self._calculate_expires_at(renew_for) is not None: - update_sql = f"UPDATE {self._session_table} SET update_time = ? WHERE id = ?" - await conn.execute(update_sql, (_datetime_to_julian(datetime.now(timezone.utc)), session_id)) + update_sql = f"UPDATE {self._session_table} SET update_time = ? WHERE app_name = ? AND user_id = ? AND id = ?" + await conn.execute( + update_sql, (_datetime_to_julian(datetime.now(timezone.utc)), app_name, user_id, session_id) + ) await conn.commit() - cursor = await conn.execute(sql, (session_id,)) + cursor = await conn.execute(sql, (app_name, user_id, session_id)) row = await cursor.fetchone() if row is None: @@ -217,10 +221,12 @@ async def get_session( return None raise - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + async def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: """Update session state. Args: + app_name: Name of the application. + user_id: ID of the user. session_id: Session identifier. state: New state dictionary (replaces existing state). @@ -235,12 +241,12 @@ async def update_session_state(self, session_id: str, state: "dict[str, Any]") - sql = f""" UPDATE {self._session_table} SET state = ?, update_time = ? - WHERE id = ? + WHERE app_name = ? AND user_id = ? AND id = ? """ async with self._config.provide_connection() as conn: await self._apply_pragmas(conn) - await conn.execute(sql, (state_json, now_julian, session_id)) + await conn.execute(sql, (state_json, now_julian, app_name, user_id, session_id)) await conn.commit() async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": @@ -295,43 +301,41 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis return [] raise - async def delete_session(self, session_id: str) -> None: + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: """Delete session and all associated events (cascade). Args: + app_name: Name of the application. + user_id: ID of the user. session_id: Session identifier. Notes: Foreign key constraint ensures events are cascade-deleted. """ - sql = f"DELETE FROM {self._session_table} WHERE id = ?" + sql = f"DELETE FROM {self._session_table} WHERE app_name = ? AND user_id = ? AND id = ?" async with self._config.provide_connection() as conn: await self._apply_pragmas(conn) - await conn.execute(sql, (session_id,)) + await conn.execute(sql, (app_name, user_id, session_id)) await conn.commit() async def append_event(self, event_record: EventRecord) -> None: """Append an event to a session. Args: - event_record: Event record with 5 keys: session_id, invocation_id, - author, timestamp, event_data. + event_record: Event record. Notes: Uses Julian Day for timestamp. event_data dict is serialized to TEXT as event_data column. """ - import uuid - timestamp_julian = _datetime_to_julian(event_record["timestamp"]) event_data_json = to_json(event_record["event_data"]) - event_id = str(uuid.uuid4()) sql = f""" INSERT INTO {self._events_table} ( - id, session_id, invocation_id, author, timestamp, event_data - ) VALUES (?, ?, ?, ?, ?, ?) + id, session_id, invocation_id, timestamp, event_data + ) VALUES (?, ?, ?, ?, ?) """ async with self._config.provide_connection() as conn: @@ -339,10 +343,9 @@ async def append_event(self, event_record: EventRecord) -> None: await conn.execute( sql, ( - event_id, + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], timestamp_julian, event_data_json, ), @@ -352,33 +355,30 @@ async def append_event(self, event_record: EventRecord) -> None: async def append_event_and_update_state( self, event_record: EventRecord, + app_name: str, + user_id: str, session_id: str, state: "dict[str, Any]", *, - app_name: "str | None" = None, - user_id: "str | None" = None, app_state: "dict[str, Any] | None" = None, user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: """Atomically append an event and update session + scoped state.""" - import uuid - timestamp_julian = _datetime_to_julian(event_record["timestamp"]) event_data_json = to_json(event_record["event_data"]) now_julian = _datetime_to_julian(datetime.now(timezone.utc)) state_json = to_json(state) - event_id = str(uuid.uuid4()) insert_sql = f""" INSERT INTO {self._events_table} ( - id, session_id, invocation_id, author, timestamp, event_data - ) VALUES (?, ?, ?, ?, ?, ?) + id, session_id, invocation_id, timestamp, event_data + ) VALUES (?, ?, ?, ?, ?) """ update_sql = f""" UPDATE {self._session_table} SET state = ?, update_time = ? - WHERE id = ? + WHERE app_name = ? AND user_id = ? AND id = ? RETURNING id, app_name, user_id, state, create_time, update_time """ @@ -403,25 +403,18 @@ async def append_event_and_update_state( await conn.execute( insert_sql, ( - event_id, + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], timestamp_julian, event_data_json, ), ) - cursor = await conn.execute(update_sql, (state_json, now_julian, session_id)) + cursor = await conn.execute(update_sql, (state_json, now_julian, app_name, user_id, session_id)) row = await cursor.fetchone() if app_state: - if app_name is None: - msg = "app_name is required when app_state is provided." - raise ValueError(msg) await conn.execute(app_upsert_sql, (app_name, to_json(app_state), now_julian)) if user_state: - if app_name is None or user_id is None: - msg = "app_name and user_id are required when user_state is provided." - raise ValueError(msg) await conn.execute(user_upsert_sql, (app_name, user_id, to_json(user_state), now_julian)) await conn.commit() @@ -439,11 +432,18 @@ async def append_event_and_update_state( ) async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": """Get events for a session. Args: + app_name: Name of the application. + user_id: ID of the user. session_id: Session identifier. after_timestamp: Only return events after this time. limit: Maximum number of events to return. @@ -455,21 +455,22 @@ async def get_events( Uses index on (session_id, timestamp ASC). Parses event_data TEXT back to dict for event_data field. """ - where_clauses = ["session_id = ?"] - params: list[Any] = [session_id] + where_clauses = ["s.app_name = ?", "s.user_id = ?", "e.session_id = ?"] + params: list[Any] = [app_name, user_id, session_id] if after_timestamp is not None: - where_clauses.append("timestamp > ?") + where_clauses.append("e.timestamp > ?") params.append(_datetime_to_julian(after_timestamp)) where_clause = " AND ".join(where_clauses) limit_clause = f" LIMIT {limit}" if limit else "" sql = f""" - SELECT id, session_id, invocation_id, author, timestamp, event_data - FROM {self._events_table} + SELECT e.id, e.session_id, e.invocation_id, e.timestamp, e.event_data, s.app_name, s.user_id + FROM {self._events_table} e + JOIN {self._session_table} s ON e.session_id = s.id WHERE {where_clause} - ORDER BY timestamp ASC{limit_clause} + ORDER BY e.timestamp ASC{limit_clause} """ try: @@ -480,11 +481,13 @@ async def get_events( return [ EventRecord( + id=row[0], session_id=row[1], invocation_id=row[2], - author=row[3], - timestamp=_julian_to_datetime(row[4]), - event_data=from_json(row[5]) if row[5] else {}, + timestamp=_julian_to_datetime(row[3]), + event_data=from_json(row[4]) if row[4] else {}, + app_name=row[5], + user_id=row[6], ) for row in rows ] @@ -669,7 +672,6 @@ async def _get_create_events_table_sql(self) -> str: id TEXT PRIMARY KEY, session_id TEXT NOT NULL, invocation_id TEXT, - author TEXT, timestamp REAL NOT NULL, event_data TEXT NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE diff --git a/sqlspec/adapters/asyncmy/adk/store.py b/sqlspec/adapters/asyncmy/adk/store.py index 48e263819..9ac89afe6 100644 --- a/sqlspec/adapters/asyncmy/adk/store.py +++ b/sqlspec/adapters/asyncmy/adk/store.py @@ -115,95 +115,74 @@ async def create_session( await cursor.execute(sql, params) await conn.commit() - return await self.get_session(session_id) # type: ignore[return-value] + result = await self.get_session(app_name, user_id, session_id) + if result is None: + msg = "Failed to fetch created session" + raise RuntimeError(msg) + return result async def get_session( - self, session_id: str, *, renew_for: "int | timedelta | None" = None + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None ) -> "SessionRecord | None": - """Get session by ID. - - Args: - session_id: Session identifier. - renew_for: If positive, touch update_time while reading. - - Returns: - Session record or None if not found. - """ + """Get session by ID.""" sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s """ try: async with self._config.provide_connection() as conn, conn.cursor() as cursor: if renew_for is not None and self._calculate_expires_at(renew_for) is not None: - update_sql = f"UPDATE {self._session_table} SET update_time = UTC_TIMESTAMP(6) WHERE id = %s" - await cursor.execute(update_sql, (session_id,)) + update_sql = f"UPDATE {self._session_table} SET update_time = UTC_TIMESTAMP(6) WHERE app_name = %s AND user_id = %s AND id = %s" + await cursor.execute(update_sql, (app_name, user_id, session_id)) await conn.commit() - await cursor.execute(sql, (session_id,)) + await cursor.execute(sql, (app_name, user_id, session_id)) row = await cursor.fetchone() if row is None: return None - session_id_val, app_name, user_id, state_json, create_time, update_time = row + session_id_val, app_name_val, user_id_val, state_json, create_time, update_time = row return SessionRecord( id=session_id_val, - app_name=app_name, - user_id=user_id, + app_name=app_name_val, + user_id=user_id_val, state=from_json(state_json) if isinstance(state_json, str) else state_json, create_time=create_time, update_time=update_time, ) - except asyncmy.errors.ProgrammingError as e: # pyright: ignore[reportAttributeAccessIssue][reportAttributeAccessIssue] + except asyncmy.errors.ProgrammingError as e: # pyright: ignore[reportAttributeAccessIssue] if "doesn't exist" in str(e) or e.args[0] == MYSQL_TABLE_NOT_FOUND_ERROR: return None raise - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state. - - Args: - session_id: Session identifier. - state: New state dictionary (replaces existing state). - """ + async def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: + """Update session state.""" state_json = to_json(state) sql = f""" UPDATE {self._session_table} - SET state = %s - WHERE id = %s + SET state = %s, update_time = UTC_TIMESTAMP(6) + WHERE app_name = %s AND user_id = %s AND id = %s """ async with self._config.provide_connection() as conn, conn.cursor() as cursor: - await cursor.execute(sql, (state_json, session_id)) + await cursor.execute(sql, (state_json, app_name, user_id, session_id)) await conn.commit() - async def delete_session(self, session_id: str) -> None: - """Delete session and all associated events (cascade). - - Args: - session_id: Session identifier. - """ - sql = f"DELETE FROM {self._session_table} WHERE id = %s" + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + """Delete session and all associated events (cascade).""" + sql = f"DELETE FROM {self._session_table} WHERE app_name = %s AND user_id = %s AND id = %s" async with self._config.provide_connection() as conn, conn.cursor() as cursor: - await cursor.execute(sql, (session_id,)) + await cursor.execute(sql, (app_name, user_id, session_id)) await conn.commit() async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app, optionally filtered by user. - - Args: - app_name: Application name. - user_id: User identifier. If None, lists all sessions for the app. - - Returns: - List of session records ordered by update_time DESC. - """ + """List sessions for an app, optionally filtered by user.""" if user_id is None: sql = f""" SELECT id, app_name, user_id, state, create_time, update_time @@ -243,18 +222,13 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis raise async def append_event(self, event_record: EventRecord) -> None: - """Append an event to a session. - - Args: - event_record: Event record with 5 keys (session_id, invocation_id, - author, timestamp, event_data). - """ + """Append an event to a session.""" event_data = event_record["event_data"] event_data_str = to_json(event_data) if not isinstance(event_data, str) else event_data sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_data + id, session_id, invocation_id, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """ @@ -262,9 +236,9 @@ async def append_event(self, event_record: EventRecord) -> None: await cursor.execute( sql, ( + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], event_data_str, ), @@ -274,40 +248,35 @@ async def append_event(self, event_record: EventRecord) -> None: async def append_event_and_update_state( self, event_record: EventRecord, + app_name: str, + user_id: str, session_id: str, state: "dict[str, Any]", *, - app_name: "str | None" = None, - user_id: "str | None" = None, app_state: "dict[str, Any] | None" = None, user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: - """Atomically append an event and update session + scoped state. - - MySQL doesn't support UPDATE...RETURNING; we follow the UPDATE with a - SELECT inside the same transaction so callers get the refreshed row - in a single round-trip pair (no separate connection acquisition). - """ + """Atomically append an event and update session + scoped state.""" event_data = event_record["event_data"] event_data_str = to_json(event_data) if not isinstance(event_data, str) else event_data state_json = to_json(state) insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_data + id, session_id, invocation_id, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """ update_sql = f""" UPDATE {self._session_table} - SET state = %s - WHERE id = %s + SET state = %s, update_time = UTC_TIMESTAMP(6) + WHERE app_name = %s AND user_id = %s AND id = %s """ select_sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s """ app_upsert_sql = f""" @@ -326,25 +295,19 @@ async def append_event_and_update_state( await cursor.execute( insert_sql, ( + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], event_data_str, ), ) - await cursor.execute(update_sql, (state_json, session_id)) - await cursor.execute(select_sql, (session_id,)) + await cursor.execute(update_sql, (state_json, app_name, user_id, session_id)) + await cursor.execute(select_sql, (app_name, user_id, session_id)) row = await cursor.fetchone() if app_state: - if app_name is None: - msg = "app_name is required when app_state is provided." - raise ValueError(msg) await cursor.execute(app_upsert_sql, (app_name, to_json(app_state))) if user_state: - if app_name is None or user_id is None: - msg = "app_name and user_id are required when user_state is provided." - raise ValueError(msg) await cursor.execute(user_upsert_sql, (app_name, user_id, to_json(user_state))) await conn.commit() @@ -363,33 +326,30 @@ async def append_event_and_update_state( ) async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": - """Get events for a session. - - Args: - session_id: Session identifier. - after_timestamp: Only return events after this time. - limit: Maximum number of events to return. - - Returns: - List of event records ordered by timestamp ASC. - """ - where_clauses = ["session_id = %s"] - params: list[Any] = [session_id] + """Get events for a session.""" + where_clauses = ["s.app_name = %s", "s.user_id = %s", "e.session_id = %s"] + params: list[Any] = [app_name, user_id, session_id] if after_timestamp is not None: - where_clauses.append("timestamp > %s") + where_clauses.append("e.timestamp > %s") params.append(after_timestamp) where_clause = " AND ".join(where_clauses) limit_clause = f" LIMIT {limit}" if limit else "" sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_data - FROM {self._events_table} + SELECT e.id, e.session_id, e.invocation_id, e.timestamp, e.event_data, s.app_name, s.user_id + FROM {self._events_table} e + JOIN {self._session_table} s ON e.session_id = s.id WHERE {where_clause} - ORDER BY timestamp ASC{limit_clause} + ORDER BY e.timestamp ASC{limit_clause} """ try: @@ -399,11 +359,13 @@ async def get_events( return [ EventRecord( - session_id=row[0], - invocation_id=row[1], - author=row[2], + id=row[0], + session_id=row[1], + invocation_id=row[2], timestamp=row[3], event_data=from_json(row[4]) if isinstance(row[4], str) else row[4], + app_name=row[5], + user_id=row[6], ) for row in rows ] @@ -575,18 +537,12 @@ async def _get_create_events_table_sql(self) -> str: Returns: SQL statement to create adk_event table with indexes. - - Notes: - Post clean-break schema: 5 columns only. - - session_id, invocation_id, author: indexed scalars - - timestamp: microsecond-precision TIMESTAMP - - event_data: full Event as native JSON """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( + id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(128) NOT NULL, + invocation_id VARCHAR(256), timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), event_data JSON NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE, diff --git a/sqlspec/adapters/asyncpg/adk/store.py b/sqlspec/adapters/asyncpg/adk/store.py index 72c62382d..d9d9d5082 100644 --- a/sqlspec/adapters/asyncpg/adk/store.py +++ b/sqlspec/adapters/asyncpg/adk/store.py @@ -73,28 +73,34 @@ async def create_session( """ await conn.execute(sql, session_id, app_name, user_id, state) - return await self.get_session(session_id) # type: ignore[return-value] + result = await self.get_session(app_name, user_id, session_id) + if result is None: + msg = "Failed to fetch created session" + raise RuntimeError(msg) + return result async def get_session( - self, session_id: str, *, renew_for: "int | timedelta | None" = None + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None ) -> "SessionRecord | None": if renew_for is not None and self._calculate_expires_at(renew_for) is not None: sql = f""" UPDATE {self._session_table} SET update_time = CURRENT_TIMESTAMP - WHERE id = $1 + WHERE app_name = $1 AND user_id = $2 AND id = $3 RETURNING id, app_name, user_id, state, create_time, update_time """ + params = [app_name, user_id, session_id] else: sql = f""" - SELECT id, app_name, user_id, state, create_time, update_time - FROM {self._session_table} - WHERE id = $1 - """ + SELECT id, app_name, user_id, state, create_time, update_time + FROM {self._session_table} + WHERE app_name = $1 AND user_id = $2 AND id = $3 + """ + params = [app_name, user_id, session_id] try: async with self._config.provide_connection() as conn: - row = await conn.fetchrow(sql, session_id) + row = await conn.fetchrow(sql, *params) if row is None: return None @@ -110,21 +116,21 @@ async def get_session( except asyncpg.exceptions.UndefinedTableError: return None - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + async def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: sql = f""" UPDATE {self._session_table} SET state = $1, update_time = CURRENT_TIMESTAMP - WHERE id = $2 + WHERE app_name = $2 AND user_id = $3 AND id = $4 """ async with self._config.provide_connection() as conn: - await conn.execute(sql, state, session_id) + await conn.execute(sql, state, app_name, user_id, session_id) - async def delete_session(self, session_id: str) -> None: - sql = f"DELETE FROM {self._session_table} WHERE id = $1" + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + sql = f"DELETE FROM {self._session_table} WHERE app_name = $1 AND user_id = $2 AND id = $3" async with self._config.provide_connection() as conn: - await conn.execute(sql, session_id) + await conn.execute(sql, app_name, user_id, session_id) async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": if user_id is None: @@ -165,16 +171,16 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis async def append_event(self, event_record: EventRecord) -> None: sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_data + id, session_id, invocation_id, timestamp, event_data ) VALUES ($1, $2, $3, $4, $5) """ async with self._config.provide_connection() as conn: await conn.execute( sql, + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], event_record["event_data"], ) @@ -182,23 +188,23 @@ async def append_event(self, event_record: EventRecord) -> None: async def append_event_and_update_state( self, event_record: EventRecord, + app_name: str, + user_id: str, session_id: str, state: "dict[str, Any]", *, - app_name: "str | None" = None, - user_id: "str | None" = None, app_state: "dict[str, Any] | None" = None, user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_data + id, session_id, invocation_id, timestamp, event_data ) VALUES ($1, $2, $3, $4, $5) """ update_sql = f""" UPDATE {self._session_table} SET state = $1, update_time = CURRENT_TIMESTAMP - WHERE id = $2 + WHERE app_name = $2 AND user_id = $3 AND id = $4 RETURNING id, app_name, user_id, state, create_time, update_time """ app_upsert_sql = f""" @@ -219,22 +225,16 @@ async def append_event_and_update_state( async with self._config.provide_connection() as conn, conn.transaction(): await conn.execute( insert_sql, + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], event_record["event_data"], ) - row = await conn.fetchrow(update_sql, state, session_id) + row = await conn.fetchrow(update_sql, state, app_name, user_id, session_id) if app_state: - if app_name is None: - msg = "app_name is required when app_state is provided." - raise ValueError(msg) await conn.execute(app_upsert_sql, app_name, app_state) if user_state: - if app_name is None or user_id is None: - msg = "app_name and user_id are required when user_state is provided." - raise ValueError(msg) await conn.execute(user_upsert_sql, app_name, user_id, user_state) if row is None: @@ -251,13 +251,18 @@ async def append_event_and_update_state( ) async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": - where_clauses = ["session_id = $1"] - params: list[Any] = [session_id] + where_clauses = ["s.app_name = $1", "s.user_id = $2", "e.session_id = $3"] + params: list[Any] = [app_name, user_id, session_id] if after_timestamp is not None: - where_clauses.append(f"timestamp > ${len(params) + 1}") + where_clauses.append(f"e.timestamp > ${len(params) + 1}") params.append(after_timestamp) where_clause = " AND ".join(where_clauses) @@ -266,10 +271,11 @@ async def get_events( params.append(limit) sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_data - FROM {self._events_table} + SELECT e.id, e.session_id, e.invocation_id, e.timestamp, e.event_data, s.app_name, s.user_id + FROM {self._events_table} e + JOIN {self._session_table} s ON e.session_id = s.id WHERE {where_clause} - ORDER BY timestamp ASC{limit_clause} + ORDER BY e.timestamp ASC{limit_clause} """ try: @@ -278,11 +284,13 @@ async def get_events( return [ EventRecord( + id=row["id"], session_id=row["session_id"], invocation_id=row["invocation_id"], - author=row["author"], timestamp=row["timestamp"], event_data=row["event_data"], + app_name=row["app_name"], + user_id=row["user_id"], ) for row in rows ] @@ -402,9 +410,9 @@ async def _get_create_sessions_table_sql(self) -> str: async def _get_create_events_table_sql(self) -> str: return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( + id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(256) NOT NULL, + invocation_id VARCHAR(256), timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, event_data JSONB NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE diff --git a/sqlspec/adapters/bigquery/adk/store.py b/sqlspec/adapters/bigquery/adk/store.py index 1c6cb4abf..eae9e5853 100644 --- a/sqlspec/adapters/bigquery/adk/store.py +++ b/sqlspec/adapters/bigquery/adk/store.py @@ -83,18 +83,18 @@ async def create_session( return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) async def get_session( - self, session_id: str, *, renew_for: "int | timedelta | None" = None + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None ) -> "SessionRecord | None": - return await async_(self._get_session)(session_id, renew_for) + return await async_(self._get_session)(app_name, user_id, session_id, renew_for) - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - await async_(self._update_session_state)(session_id, state) + async def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: + await async_(self._update_session_state)(app_name, user_id, session_id, state) async def list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[SessionRecord]": return await async_(self._list_sessions)(app_name, user_id) - async def delete_session(self, session_id: str) -> None: - await async_(self._delete_session)(session_id) + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + await async_(self._delete_session)(app_name, user_id, session_id) async def append_event(self, event_record: EventRecord) -> None: await async_(self._append_event)(event_record) @@ -102,28 +102,27 @@ async def append_event(self, event_record: EventRecord) -> None: async def append_event_and_update_state( self, event_record: EventRecord, + app_name: str, + user_id: str, session_id: str, state: "dict[str, Any]", *, - app_name: "str | None" = None, - user_id: "str | None" = None, app_state: "dict[str, Any] | None" = None, user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: return await async_(self._append_event_and_update_state)( - event_record, - session_id, - state, - app_name=app_name, - user_id=user_id, - app_state=app_state, - user_state=user_state, + event_record, app_name, user_id, session_id, state, app_state=app_state, user_state=user_state ) async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": - return await async_(self._get_events)(session_id, after_timestamp, limit) + return await async_(self._get_events)(app_name, user_id, session_id, after_timestamp, limit) async def delete_expired_events(self, before: datetime) -> int: return await async_(self._delete_expired_events)(before) @@ -201,17 +200,26 @@ def _create_session( "update_time": now, } - def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = None) -> "SessionRecord | None": + def _get_session( + self, app_name: str, user_id: str, session_id: str, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": if renew_for is not None and self._calculate_expires_at(renew_for) is not None: - self._update_session_touch(session_id) + self._update_session_touch(app_name, user_id, session_id) sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._qualified(self._session_table)} - WHERE id = @id + WHERE app_name = @app_name AND user_id = @user_id AND id = @id LIMIT 1 """ - rows = self._run_query(sql, [self._query_param("id", session_id)]) + rows = self._run_query( + sql, + [ + self._query_param("app_name", app_name), + self._query_param("user_id", user_id), + self._query_param("id", session_id), + ], + ) if not rows: return None row = rows[0] @@ -225,21 +233,36 @@ def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = No } return record - def _update_session_touch(self, session_id: str) -> None: + def _update_session_touch(self, app_name: str, user_id: str, session_id: str) -> None: sql = f""" UPDATE {self._qualified(self._session_table)} SET update_time = CURRENT_TIMESTAMP() - WHERE id = @id + WHERE app_name = @app_name AND user_id = @user_id AND id = @id """ - self._run_query(sql, [self._query_param("id", session_id)]) + self._run_query( + sql, + [ + self._query_param("app_name", app_name), + self._query_param("user_id", user_id), + self._query_param("id", session_id), + ], + ) - def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + def _update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: sql = f""" UPDATE {self._qualified(self._session_table)} SET state = @state, update_time = CURRENT_TIMESTAMP() - WHERE id = @id + WHERE app_name = @app_name AND user_id = @user_id AND id = @id """ - self._run_query(sql, [self._json_param("state", state), self._query_param("id", session_id)]) + self._run_query( + sql, + [ + self._json_param("state", state), + self._query_param("app_name", app_name), + self._query_param("user_id", user_id), + self._query_param("id", session_id), + ], + ) def _list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[SessionRecord]": window_start = datetime.now(timezone.utc) - timedelta(days=self._lookup_window_days) @@ -271,22 +294,29 @@ def _list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[S records.append(record) return records - def _delete_session(self, session_id: str) -> None: + def _delete_session(self, app_name: str, user_id: str, session_id: str) -> None: events_sql = f"DELETE FROM {self._qualified(self._events_table)} WHERE session_id = @id" - sessions_sql = f"DELETE FROM {self._qualified(self._session_table)} WHERE id = @id" + sessions_sql = f"DELETE FROM {self._qualified(self._session_table)} WHERE app_name = @app_name AND user_id = @user_id AND id = @id" self._run_query(events_sql, [self._query_param("id", session_id)]) - self._run_query(sessions_sql, [self._query_param("id", session_id)]) + self._run_query( + sessions_sql, + [ + self._query_param("app_name", app_name), + self._query_param("user_id", user_id), + self._query_param("id", session_id), + ], + ) def _append_event(self, event_record: EventRecord) -> None: sql = f""" INSERT INTO {self._qualified(self._events_table)} - (session_id, invocation_id, author, timestamp, event_data) - VALUES (@session_id, @invocation_id, @author, @timestamp, @event_data) + (id, session_id, invocation_id, timestamp, event_data) + VALUES (@id, @session_id, @invocation_id, @timestamp, @event_data) """ params = [ + self._query_param("id", event_record["id"]), self._query_param("session_id", event_record["session_id"]), self._query_param("invocation_id", event_record["invocation_id"]), - self._query_param("author", event_record["author"]), self._query_param("timestamp", event_record["timestamp"], bq_type="TIMESTAMP"), self._json_param("event_data", event_record["event_data"]), ] @@ -295,61 +325,63 @@ def _append_event(self, event_record: EventRecord) -> None: def _append_event_and_update_state( self, event_record: EventRecord, + app_name: str, + user_id: str, session_id: str, state: "dict[str, Any]", *, - app_name: "str | None" = None, - user_id: "str | None" = None, app_state: "dict[str, Any] | None" = None, user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: - if app_state and app_name is None: - msg = "app_name is required when app_state is provided." - raise ValueError(msg) - if user_state and (app_name is None or user_id is None): - msg = "app_name and user_id are required when user_state is provided." - raise ValueError(msg) - - # BigQuery DML statements are not transactional across separate jobs. We - # accept this trade-off because the BigQuery ADK store is positioned as - # the analytics-replica path, not a live OLTP store. self._append_event(event_record) - self._update_session_state(session_id, state) + self._update_session_state(app_name, user_id, session_id, state) if app_state: - self._upsert_app_state(cast("str", app_name), app_state) + self._upsert_app_state(app_name, app_state) if user_state: - self._upsert_user_state(cast("str", app_name), cast("str", user_id), user_state) + self._upsert_user_state(app_name, user_id, user_state) - record = self._get_session(session_id) + record = self._get_session(app_name, user_id, session_id) if record is None: msg = f"Session {session_id} not found during append_event_and_update_state." raise ValueError(msg) return record def _get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_data - FROM {self._qualified(self._events_table)} - WHERE session_id = @session_id + SELECT e.id, e.session_id, e.invocation_id, e.timestamp, e.event_data, s.app_name, s.user_id + FROM {self._qualified(self._events_table)} e + JOIN {self._qualified(self._session_table)} s ON e.session_id = s.id + WHERE s.app_name = @app_name AND s.user_id = @user_id AND e.session_id = @session_id """ - params = [self._query_param("session_id", session_id)] + params = [ + self._query_param("app_name", app_name), + self._query_param("user_id", user_id), + self._query_param("session_id", session_id), + ] if after_timestamp is not None: - sql += " AND timestamp > @after_timestamp" + sql += " AND e.timestamp > @after_timestamp" params.append(self._query_param("after_timestamp", after_timestamp, bq_type="TIMESTAMP")) - sql += " ORDER BY timestamp ASC" + sql += " ORDER BY e.timestamp ASC" if limit is not None: sql += " LIMIT @row_limit" params.append(self._query_param("row_limit", limit, bq_type="INT64")) rows = self._run_query(sql, params) return [ { + "id": row["id"], "session_id": row["session_id"], "invocation_id": row["invocation_id"], - "author": row["author"], "timestamp": row["timestamp"], "event_data": self._decode_json(row["event_data"]) or {}, + "app_name": row["app_name"], + "user_id": row["user_id"], } for row in rows ] @@ -463,15 +495,15 @@ async def _get_create_sessions_table_sql(self) -> str: async def _get_create_events_table_sql(self) -> str: return f""" CREATE TABLE IF NOT EXISTS {self._qualified(self._events_table)} ( + id STRING NOT NULL, session_id STRING NOT NULL, - invocation_id STRING NOT NULL, - author STRING NOT NULL, + invocation_id STRING, timestamp TIMESTAMP NOT NULL, event_data JSON ) PARTITION BY DATE(timestamp) - CLUSTER BY session_id, app_name_cluster_placeholder, user_id_cluster_placeholder{self._partition_options()} - """.replace(", app_name_cluster_placeholder, user_id_cluster_placeholder", "") + CLUSTER BY session_id, id{self._partition_options()} + """ async def _get_create_app_states_table_sql(self) -> str: return f""" diff --git a/sqlspec/adapters/cockroach_asyncpg/adk/store.py b/sqlspec/adapters/cockroach_asyncpg/adk/store.py index 5d6d1c2bd..a7caba06a 100644 --- a/sqlspec/adapters/cockroach_asyncpg/adk/store.py +++ b/sqlspec/adapters/cockroach_asyncpg/adk/store.py @@ -69,32 +69,34 @@ async def create_session( async with self._config.provide_connection() as conn: await conn.execute(sql, *params) - result = await self.get_session(session_id) + result = await self.get_session(app_name, user_id, session_id) if result is None: msg = "Session creation failed" raise RuntimeError(msg) return result async def get_session( - self, session_id: str, *, renew_for: "int | timedelta | None" = None + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None ) -> "SessionRecord | None": if renew_for is not None and self._calculate_expires_at(renew_for) is not None: sql = f""" UPDATE {self._session_table} SET update_time = CURRENT_TIMESTAMP - WHERE id = $1 + WHERE app_name = $1 AND user_id = $2 AND id = $3 RETURNING id, app_name, user_id, state, create_time, update_time """ + params = (app_name, user_id, session_id) else: sql = f""" - SELECT id, app_name, user_id, state, create_time, update_time - FROM {self._session_table} - WHERE id = $1 - """ + SELECT id, app_name, user_id, state, create_time, update_time + FROM {self._session_table} + WHERE app_name = $1 AND user_id = $2 AND id = $3 + """ + params = (app_name, user_id, session_id) try: async with self._config.provide_connection() as conn: - row = await conn.fetchrow(sql, session_id) + row = await conn.fetchrow(sql, *params) if row is None: return None @@ -109,21 +111,21 @@ async def get_session( except asyncpg.exceptions.UndefinedTableError: return None - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + async def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: sql = f""" UPDATE {self._session_table} SET state = $1, update_time = CURRENT_TIMESTAMP - WHERE id = $2 + WHERE app_name = $2 AND user_id = $3 AND id = $4 """ async with self._config.provide_connection() as conn: - await conn.execute(sql, state, session_id) + await conn.execute(sql, state, app_name, user_id, session_id) - async def delete_session(self, session_id: str) -> None: - sql = f"DELETE FROM {self._session_table} WHERE id = $1" + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + sql = f"DELETE FROM {self._session_table} WHERE app_name = $1 AND user_id = $2 AND id = $3" async with self._config.provide_connection() as conn: - await conn.execute(sql, session_id) + await conn.execute(sql, app_name, user_id, session_id) async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": if user_id is None: @@ -164,16 +166,16 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis async def append_event(self, event_record: EventRecord) -> None: sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_data + id, session_id, invocation_id, timestamp, event_data ) VALUES ($1, $2, $3, $4, $5) """ async with self._config.provide_connection() as conn: await conn.execute( sql, + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], event_record["event_data"], ) @@ -181,23 +183,23 @@ async def append_event(self, event_record: EventRecord) -> None: async def append_event_and_update_state( self, event_record: EventRecord, + app_name: str, + user_id: str, session_id: str, state: "dict[str, Any]", *, - app_name: "str | None" = None, - user_id: "str | None" = None, app_state: "dict[str, Any] | None" = None, user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_data + id, session_id, invocation_id, timestamp, event_data ) VALUES ($1, $2, $3, $4, $5) """ update_sql = f""" UPDATE {self._session_table} SET state = $1, update_time = CURRENT_TIMESTAMP - WHERE id = $2 + WHERE app_name = $2 AND user_id = $3 AND id = $4 RETURNING id, app_name, user_id, state, create_time, update_time """ app_upsert_sql = f""" @@ -218,22 +220,16 @@ async def append_event_and_update_state( async with self._config.provide_connection() as conn, conn.transaction(): await conn.execute( insert_sql, + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], event_record["event_data"], ) - row = await conn.fetchrow(update_sql, state, session_id) + row = await conn.fetchrow(update_sql, state, app_name, user_id, session_id) if app_state: - if app_name is None: - msg = "app_name is required when app_state is provided." - raise ValueError(msg) await conn.execute(app_upsert_sql, app_name, app_state) if user_state: - if app_name is None or user_id is None: - msg = "app_name and user_id are required when user_state is provided." - raise ValueError(msg) await conn.execute(user_upsert_sql, app_name, user_id, user_state) if row is None: @@ -250,13 +246,18 @@ async def append_event_and_update_state( ) async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": - where_clauses = ["session_id = $1"] - params: list[Any] = [session_id] + where_clauses = ["s.app_name = $1", "s.user_id = $2", "e.session_id = $3"] + params: list[Any] = [app_name, user_id, session_id] if after_timestamp is not None: - where_clauses.append(f"timestamp > ${len(params) + 1}") + where_clauses.append(f"e.timestamp > ${len(params) + 1}") params.append(after_timestamp) where_clause = " AND ".join(where_clauses) @@ -265,10 +266,11 @@ async def get_events( params.append(limit) sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_data - FROM {self._events_table} + SELECT e.id, e.session_id, e.invocation_id, e.timestamp, e.event_data, s.app_name, s.user_id + FROM {self._events_table} e + JOIN {self._session_table} s ON e.session_id = s.id WHERE {where_clause} - ORDER BY timestamp ASC{limit_clause} + ORDER BY e.timestamp ASC{limit_clause} """ try: @@ -279,11 +281,13 @@ async def get_events( return [ EventRecord( + id=row["id"], session_id=row["session_id"], invocation_id=row["invocation_id"], - author=row["author"], timestamp=row["timestamp"], event_data=row["event_data"], + app_name=row["app_name"], + user_id=row["user_id"], ) for row in rows ] @@ -394,9 +398,9 @@ async def _get_create_sessions_table_sql(self) -> str: async def _get_create_events_table_sql(self) -> str: return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( + id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(256) NOT NULL, + invocation_id VARCHAR(256), timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, event_data JSONB NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE diff --git a/sqlspec/adapters/cockroach_psycopg/adk/store.py b/sqlspec/adapters/cockroach_psycopg/adk/store.py index 9980611a7..7038f51eb 100644 --- a/sqlspec/adapters/cockroach_psycopg/adk/store.py +++ b/sqlspec/adapters/cockroach_psycopg/adk/store.py @@ -113,32 +113,34 @@ async def create_session( await cur.execute(sql.encode(), params) await conn.commit() - result = await self.get_session(session_id) + result = await self.get_session(app_name, user_id, session_id) if result is None: msg = "Session creation failed" raise RuntimeError(msg) return result async def get_session( - self, session_id: str, *, renew_for: "int | timedelta | None" = None + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None ) -> "SessionRecord | None": if renew_for is not None and self._calculate_expires_at(renew_for) is not None: sql = f""" UPDATE {self._session_table} SET update_time = CURRENT_TIMESTAMP - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s RETURNING id, app_name, user_id, state, create_time, update_time """ + params = (app_name, user_id, session_id) else: sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s """ + params = (app_name, user_id, session_id) try: async with self._config.provide_connection() as conn, conn.cursor() as cur: - await cur.execute(sql.encode(), (session_id,)) + await cur.execute(sql.encode(), params) row = await cur.fetchone() if row is None: @@ -155,22 +157,22 @@ async def get_session( except errors.UndefinedTable: return None - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + async def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: sql = f""" UPDATE {self._session_table} SET state = %s, update_time = CURRENT_TIMESTAMP - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s """ async with self._config.provide_connection() as conn, conn.cursor() as cur: - await cur.execute(sql.encode(), (Jsonb(state), session_id)) + await cur.execute(sql.encode(), (Jsonb(state), app_name, user_id, session_id)) await conn.commit() - async def delete_session(self, session_id: str) -> None: - sql = f"DELETE FROM {self._session_table} WHERE id = %s" + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + sql = f"DELETE FROM {self._session_table} WHERE app_name = %s AND user_id = %s AND id = %s" async with self._config.provide_connection() as conn, conn.cursor() as cur: - await cur.execute(sql.encode(), (session_id,)) + await cur.execute(sql.encode(), (app_name, user_id, session_id)) await conn.commit() async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": @@ -213,7 +215,7 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis async def append_event(self, event_record: EventRecord) -> None: sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_data + id, session_id, invocation_id, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """ @@ -224,9 +226,9 @@ async def append_event(self, event_record: EventRecord) -> None: await cur.execute( sql.encode(), ( + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], jsonb_value, ), @@ -236,23 +238,23 @@ async def append_event(self, event_record: EventRecord) -> None: async def append_event_and_update_state( self, event_record: EventRecord, + app_name: str, + user_id: str, session_id: str, state: "dict[str, Any]", *, - app_name: "str | None" = None, - user_id: "str | None" = None, app_state: "dict[str, Any] | None" = None, user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_data + id, session_id, invocation_id, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """ update_sql = f""" UPDATE {self._session_table} SET state = %s, update_time = CURRENT_TIMESTAMP - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s RETURNING id, app_name, user_id, state, create_time, update_time """ app_upsert_sql = f""" @@ -277,24 +279,18 @@ async def append_event_and_update_state( await cur.execute( insert_sql.encode(), ( + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], jsonb_value, ), ) - await cur.execute(update_sql.encode(), (Jsonb(state), session_id)) + await cur.execute(update_sql.encode(), (Jsonb(state), app_name, user_id, session_id)) row = await cur.fetchone() if app_state: - if app_name is None: - msg = "app_name is required when app_state is provided." - raise ValueError(msg) await cur.execute(app_upsert_sql.encode(), (app_name, Jsonb(app_state))) if user_state: - if app_name is None or user_id is None: - msg = "app_name and user_id are required when user_state is provided." - raise ValueError(msg) await cur.execute(user_upsert_sql.encode(), (app_name, user_id, Jsonb(user_state))) await conn.commit() @@ -312,26 +308,31 @@ async def append_event_and_update_state( ) async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": - where_clauses = ["session_id = %s"] - params: list[Any] = [session_id] + where_clauses = ["s.app_name = %s", "s.user_id = %s", "e.session_id = %s"] + params: list[Any] = [app_name, user_id, session_id] if after_timestamp is not None: - where_clauses.append("timestamp > %s") + where_clauses.append("e.timestamp > %s") params.append(after_timestamp) where_clause = " AND ".join(where_clauses) limit_clause = " LIMIT %s" if limit else "" - if limit: - params.append(limit) - sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_data - FROM {self._events_table} + SELECT e.id, e.session_id, e.invocation_id, e.timestamp, e.event_data, s.app_name, s.user_id + FROM {self._events_table} e + JOIN {self._session_table} s ON e.session_id = s.id WHERE {where_clause} - ORDER BY timestamp ASC{limit_clause} + ORDER BY e.timestamp ASC{limit_clause} """ + if limit: + params.append(limit) try: async with self._config.provide_connection() as conn, conn.cursor() as cur: @@ -340,11 +341,13 @@ async def get_events( return [ EventRecord( + id=row["id"], session_id=row["session_id"], invocation_id=row["invocation_id"], - author=row["author"], timestamp=row["timestamp"], event_data=row["event_data"], + app_name=row["app_name"], + user_id=row["user_id"], ) for row in rows ] @@ -465,9 +468,9 @@ async def _get_create_sessions_table_sql(self) -> str: async def _get_create_events_table_sql(self) -> str: return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( + id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(256) NOT NULL, + invocation_id VARCHAR(256), timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, event_data JSONB NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE @@ -564,18 +567,18 @@ async def create_session( return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) async def get_session( - self, session_id: str, *, renew_for: "int | timedelta | None" = None + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None ) -> "SessionRecord | None": """Get session by ID.""" - return await async_(self._get_session)(session_id, renew_for) + return await async_(self._get_session)(app_name, user_id, session_id, renew_for) - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + async def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: """Update session state.""" - await async_(self._update_session_state)(session_id, state) + await async_(self._update_session_state)(app_name, user_id, session_id, state) - async def delete_session(self, session_id: str) -> None: + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: """Delete session and associated events.""" - await async_(self._delete_session)(session_id) + await async_(self._delete_session)(app_name, user_id, session_id) async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": """List sessions for an app.""" @@ -584,30 +587,29 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis async def append_event_and_update_state( self, event_record: EventRecord, + app_name: str, + user_id: str, session_id: str, state: "dict[str, Any]", *, - app_name: "str | None" = None, - user_id: "str | None" = None, app_state: "dict[str, Any] | None" = None, user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: """Atomically append an event and update session + scoped state.""" return await async_(self._append_event_and_update_state)( - event_record, - session_id, - state, - app_name=app_name, - user_id=user_id, - app_state=app_state, - user_state=user_state, + event_record, app_name, user_id, session_id, state, app_state=app_state, user_state=user_state ) async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": """Get events for a session.""" - return await async_(self._get_events)(session_id, after_timestamp, limit) + return await async_(self._get_events)(app_name, user_id, session_id, after_timestamp, limit) async def delete_expired_events(self, before: "datetime") -> int: """Delete events older than the given timestamp.""" @@ -674,9 +676,9 @@ async def _get_create_sessions_table_sql(self) -> str: async def _get_create_events_table_sql(self) -> str: return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( + id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(256) NOT NULL, + invocation_id VARCHAR(256), timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, event_data JSONB NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE @@ -774,30 +776,34 @@ def _create_session( cur.execute(sql.encode(), params) conn.commit() - result = self._get_session(session_id) + result = self._get_session(app_name, user_id, session_id) if result is None: msg = "Session creation failed" raise RuntimeError(msg) return result - def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = None) -> "SessionRecord | None": + def _get_session( + self, app_name: str, user_id: str, session_id: str, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": if renew_for is not None and self._calculate_expires_at(renew_for) is not None: sql = f""" UPDATE {self._session_table} SET update_time = CURRENT_TIMESTAMP - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s RETURNING id, app_name, user_id, state, create_time, update_time """ + params = (app_name, user_id, session_id) else: sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s """ + params = (app_name, user_id, session_id) try: with self._config.provide_connection() as conn, conn.cursor() as cur: - cur.execute(sql.encode(), (session_id,)) + cur.execute(sql.encode(), params) row = cur.fetchone() if row is None: @@ -814,22 +820,22 @@ def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = No except errors.UndefinedTable: return None - def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + def _update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: sql = f""" UPDATE {self._session_table} SET state = %s, update_time = CURRENT_TIMESTAMP - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s """ with self._config.provide_connection() as conn, conn.cursor() as cur: - cur.execute(sql.encode(), (Jsonb(state), session_id)) + cur.execute(sql.encode(), (Jsonb(state), app_name, user_id, session_id)) conn.commit() - def _delete_session(self, session_id: str) -> None: - sql = f"DELETE FROM {self._session_table} WHERE id = %s" + def _delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + sql = f"DELETE FROM {self._session_table} WHERE app_name = %s AND user_id = %s AND id = %s" with self._config.provide_connection() as conn, conn.cursor() as cur: - cur.execute(sql.encode(), (session_id,)) + cur.execute(sql.encode(), (app_name, user_id, session_id)) conn.commit() def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": @@ -872,23 +878,23 @@ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Ses def _append_event_and_update_state( self, event_record: EventRecord, + app_name: str, + user_id: str, session_id: str, state: "dict[str, Any]", *, - app_name: "str | None" = None, - user_id: "str | None" = None, app_state: "dict[str, Any] | None" = None, user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_data + id, session_id, invocation_id, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """ update_sql = f""" UPDATE {self._session_table} SET state = %s, update_time = CURRENT_TIMESTAMP - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s RETURNING id, app_name, user_id, state, create_time, update_time """ app_upsert_sql = f""" @@ -913,24 +919,18 @@ def _append_event_and_update_state( cur.execute( insert_sql.encode(), ( + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], jsonb_value, ), ) - cur.execute(update_sql.encode(), (Jsonb(state), session_id)) + cur.execute(update_sql.encode(), (Jsonb(state), app_name, user_id, session_id)) row = cur.fetchone() if app_state: - if app_name is None: - msg = "app_name is required when app_state is provided." - raise ValueError(msg) cur.execute(app_upsert_sql.encode(), (app_name, Jsonb(app_state))) if user_state: - if app_name is None or user_id is None: - msg = "app_name and user_id are required when user_state is provided." - raise ValueError(msg) cur.execute(user_upsert_sql.encode(), (app_name, user_id, Jsonb(user_state))) conn.commit() @@ -950,7 +950,7 @@ def _append_event_and_update_state( def _insert_event(self, event_record: EventRecord) -> None: sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_data + id, session_id, invocation_id, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """ @@ -961,9 +961,9 @@ def _insert_event(self, event_record: EventRecord) -> None: cur.execute( sql.encode(), ( + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], jsonb_value, ), @@ -971,22 +971,28 @@ def _insert_event(self, event_record: EventRecord) -> None: conn.commit() def _get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": - where_clauses = ["session_id = %s"] - params: list[Any] = [session_id] + where_clauses = ["s.app_name = %s", "s.user_id = %s", "e.session_id = %s"] + params: list[Any] = [app_name, user_id, session_id] if after_timestamp is not None: - where_clauses.append("timestamp > %s") + where_clauses.append("e.timestamp > %s") params.append(after_timestamp) where_clause = " AND ".join(where_clauses) limit_clause = " LIMIT %s" if limit else "" sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_data - FROM {self._events_table} + SELECT e.id, e.session_id, e.invocation_id, e.timestamp, e.event_data, s.app_name, s.user_id + FROM {self._events_table} e + JOIN {self._session_table} s ON e.session_id = s.id WHERE {where_clause} - ORDER BY timestamp ASC{limit_clause} + ORDER BY e.timestamp ASC{limit_clause} """ if limit: params.append(limit) @@ -998,11 +1004,13 @@ def _get_events( return [ EventRecord( + id=row["id"], session_id=row["session_id"], invocation_id=row["invocation_id"], - author=row["author"], timestamp=row["timestamp"], event_data=row["event_data"], + app_name=row["app_name"], + user_id=row["user_id"], ) for row in rows ] diff --git a/sqlspec/adapters/duckdb/adk/store.py b/sqlspec/adapters/duckdb/adk/store.py index 2fd2f1d3e..63c750426 100644 --- a/sqlspec/adapters/duckdb/adk/store.py +++ b/sqlspec/adapters/duckdb/adk/store.py @@ -118,11 +118,13 @@ async def create_session( return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) async def get_session( - self, session_id: str, *, renew_for: "int | timedelta | None" = None + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None ) -> "SessionRecord | None": - """Get session by ID. + """Get session. Args: + app_name: Name of the application. + user_id: ID of the user. session_id: Session identifier. renew_for: If positive, touch update_time while reading. @@ -133,12 +135,14 @@ async def get_session( DuckDB returns datetime objects for TIMESTAMPTZ columns. JSON is parsed from database storage. """ - return await async_(self._get_session)(session_id, renew_for) + return await async_(self._get_session)(app_name, user_id, session_id, renew_for) - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + async def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: """Update session state. Args: + app_name: Name of the application. + user_id: ID of the user. session_id: Session identifier. state: New state dictionary (replaces existing state). @@ -146,18 +150,20 @@ async def update_session_state(self, session_id: str, state: "dict[str, Any]") - This replaces the entire state dictionary. Update time is automatically set to current UTC timestamp. """ - await async_(self._update_session_state)(session_id, state) + await async_(self._update_session_state)(app_name, user_id, session_id, state) - async def delete_session(self, session_id: str) -> None: + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: """Delete session and all associated events. Args: + app_name: Name of the application. + user_id: ID of the user. session_id: Session identifier. Notes: DuckDB doesn't support CASCADE in foreign keys, so we manually delete events first. """ - await async_(self._delete_session)(session_id) + await async_(self._delete_session)(app_name, user_id, session_id) async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": """List sessions for an app, optionally filtered by user. @@ -178,39 +184,39 @@ async def append_event(self, event_record: EventRecord) -> None: """Append an event to a session. Args: - event_record: Event record with 5 keys (session_id, invocation_id, - author, timestamp, event_data). + event_record: Event record. """ await async_(self._append_event)(event_record) async def append_event_and_update_state( self, event_record: EventRecord, + app_name: str, + user_id: str, session_id: str, state: "dict[str, Any]", *, - app_name: "str | None" = None, - user_id: "str | None" = None, app_state: "dict[str, Any] | None" = None, user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: """Atomically append an event and update session + scoped state.""" return await async_(self._append_event_and_update_state)( - event_record, - session_id, - state, - app_name=app_name, - user_id=user_id, - app_state=app_state, - user_state=user_state, + event_record, app_name, user_id, session_id, state, app_state=app_state, user_state=user_state ) async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": """Get events for a session. Args: + app_name: Name of the application. + user_id: ID of the user. session_id: Session identifier. after_timestamp: Only return events after this time. limit: Maximum number of events to return. @@ -218,7 +224,7 @@ async def get_events( Returns: List of event records ordered by timestamp ASC. """ - return await async_(self._get_events)(session_id, after_timestamp, limit) + return await async_(self._get_events)(app_name, user_id, session_id, after_timestamp, limit) async def delete_expired_events(self, before: "datetime") -> int: """Delete events older than the given timestamp.""" @@ -291,7 +297,7 @@ async def _get_create_events_table_sql(self) -> str: SQL statement to create adk_event table with indexes. Notes: - - 5-column schema: session_id, invocation_id, author, timestamp, event_data + - 5-column schema: id, session_id, invocation_id, timestamp, event_data - event_data stores the full ADK Event as a single JSON blob - No decomposed columns -- eliminates column drift with upstream ADK - Foreign key constraint (DuckDB doesn't support CASCADE) @@ -300,9 +306,9 @@ async def _get_create_events_table_sql(self) -> str: """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( + id VARCHAR PRIMARY KEY, session_id VARCHAR NOT NULL, invocation_id VARCHAR NOT NULL, - author VARCHAR NOT NULL, timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, event_data JSON NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) @@ -389,9 +395,9 @@ def __get_create_events_table_sql_sync(self) -> str: """Synchronous version of DDL generation for use in _create_tables.""" return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( + id VARCHAR PRIMARY KEY, session_id VARCHAR NOT NULL, invocation_id VARCHAR NOT NULL, - author VARCHAR NOT NULL, timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, event_data JSON NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) @@ -467,35 +473,37 @@ def _create_session( id=session_id, app_name=app_name, user_id=user_id, state=state, create_time=now, update_time=now ) - def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = None) -> "SessionRecord | None": + def _get_session( + self, app_name: str, user_id: str, session_id: str, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": """Synchronous implementation of get_session.""" sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} - WHERE id = ? + WHERE app_name = ? AND user_id = ? AND id = ? """ try: with self._config.provide_connection() as conn: if renew_for is not None and self._calculate_expires_at(renew_for) is not None: - update_sql = f"UPDATE {self._session_table} SET update_time = ? WHERE id = ?" - conn.execute(update_sql, (datetime.now(timezone.utc), session_id)) + update_sql = f"UPDATE {self._session_table} SET update_time = ? WHERE app_name = ? AND user_id = ? AND id = ?" + conn.execute(update_sql, (datetime.now(timezone.utc), app_name, user_id, session_id)) conn.commit() - cursor = conn.execute(sql, (session_id,)) + cursor = conn.execute(sql, (app_name, user_id, session_id)) row = cursor.fetchone() if row is None: return None - session_id_val, app_name, user_id, state_data, create_time, update_time = row + session_id_val, row_app_name, row_user_id, state_data, create_time, update_time = row state = from_json(state_data) if state_data else {} return SessionRecord( id=session_id_val, - app_name=app_name, - user_id=user_id, + app_name=row_app_name, + user_id=row_user_id, state=state, create_time=create_time, update_time=update_time, @@ -505,7 +513,7 @@ def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = No return None raise - def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + def _update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: """Synchronous implementation of update_session_state.""" now = datetime.now(timezone.utc) state_json = to_json(state) @@ -513,21 +521,27 @@ def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> Non sql = f""" UPDATE {self._session_table} SET state = ?, update_time = ? - WHERE id = ? + WHERE app_name = ? AND user_id = ? AND id = ? """ with self._config.provide_connection() as conn: - conn.execute(sql, (state_json, now, session_id)) + conn.execute(sql, (state_json, now, app_name, user_id, session_id)) conn.commit() - def _delete_session(self, session_id: str) -> None: + def _delete_session(self, app_name: str, user_id: str, session_id: str) -> None: """Synchronous implementation of delete_session.""" - delete_events_sql = f"DELETE FROM {self._events_table} WHERE session_id = ?" - delete_session_sql = f"DELETE FROM {self._session_table} WHERE id = ?" + delete_events_sql = f""" + DELETE FROM {self._events_table} + WHERE session_id IN ( + SELECT id FROM {self._session_table} + WHERE app_name = ? AND user_id = ? AND id = ? + ) + """ + delete_session_sql = f"DELETE FROM {self._session_table} WHERE app_name = ? AND user_id = ? AND id = ?" with self._config.provide_connection() as conn: - conn.execute(delete_events_sql, (session_id,)) - conn.execute(delete_session_sql, (session_id,)) + conn.execute(delete_events_sql, (app_name, user_id, session_id)) + conn.execute(delete_session_sql, (app_name, user_id, session_id)) conn.commit() def _list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[SessionRecord]": @@ -576,7 +590,7 @@ def _append_event(self, event_record: EventRecord) -> None: sql = f""" INSERT INTO {self._events_table} - (session_id, invocation_id, author, timestamp, event_data) + (id, session_id, invocation_id, timestamp, event_data) VALUES (?, ?, ?, ?, ?) """ @@ -584,9 +598,9 @@ def _append_event(self, event_record: EventRecord) -> None: conn.execute( sql, ( + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], event_data_str, ), @@ -596,11 +610,11 @@ def _append_event(self, event_record: EventRecord) -> None: def _append_event_and_update_state( self, event_record: EventRecord, + app_name: str, + user_id: str, session_id: str, state: "dict[str, Any]", *, - app_name: "str | None" = None, - user_id: "str | None" = None, app_state: "dict[str, Any] | None" = None, user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: @@ -611,14 +625,14 @@ def _append_event_and_update_state( insert_sql = f""" INSERT INTO {self._events_table} - (session_id, invocation_id, author, timestamp, event_data) + (id, session_id, invocation_id, timestamp, event_data) VALUES (?, ?, ?, ?, ?) """ update_sql = f""" UPDATE {self._session_table} SET state = ?, update_time = ? - WHERE id = ? + WHERE app_name = ? AND user_id = ? AND id = ? RETURNING id, app_name, user_id, state, create_time, update_time """ @@ -639,28 +653,22 @@ def _append_event_and_update_state( """ with self._config.provide_connection() as conn: - cursor = conn.execute(update_sql, (state_json, now, session_id)) + cursor = conn.execute(update_sql, (state_json, now, app_name, user_id, session_id)) row = cursor.fetchone() if row is not None: conn.execute( insert_sql, ( + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], event_data_str, ), ) if app_state: - if app_name is None: - msg = "app_name is required when app_state is provided." - raise ValueError(msg) conn.execute(app_upsert_sql, (app_name, to_json(app_state), now)) if user_state: - if app_name is None or user_id is None: - msg = "app_name and user_id are required when user_state is provided." - raise ValueError(msg) conn.execute(user_upsert_sql, (app_name, user_id, to_json(user_state), now)) conn.commit() @@ -679,24 +687,30 @@ def _append_event_and_update_state( ) def _get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": """Synchronous implementation of get_events.""" - where_clauses = ["session_id = ?"] - params: list[Any] = [session_id] + where_clauses = ["s.app_name = ?", "s.user_id = ?", "e.session_id = ?"] + params: list[Any] = [app_name, user_id, session_id] if after_timestamp is not None: - where_clauses.append("timestamp > ?") + where_clauses.append("e.timestamp > ?") params.append(after_timestamp) where_clause = " AND ".join(where_clauses) limit_clause = f" LIMIT {limit}" if limit else "" sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_data - FROM {self._events_table} + SELECT e.id, e.session_id, e.invocation_id, e.timestamp, e.event_data, s.app_name, s.user_id + FROM {self._events_table} e + JOIN {self._session_table} s ON e.session_id = s.id WHERE {where_clause} - ORDER BY timestamp ASC{limit_clause} + ORDER BY e.timestamp ASC{limit_clause} """ try: @@ -706,11 +720,13 @@ def _get_events( return [ EventRecord( - session_id=row[0], - invocation_id=row[1], - author=row[2], + id=row[0], + session_id=row[1], + invocation_id=row[2], timestamp=row[3], event_data=from_json(row[4]) if isinstance(row[4], str) else row[4], + app_name=row[5], + user_id=row[6], ) for row in rows ] diff --git a/sqlspec/adapters/mysqlconnector/adk/store.py b/sqlspec/adapters/mysqlconnector/adk/store.py index 9b42ce861..ddf5ab86b 100644 --- a/sqlspec/adapters/mysqlconnector/adk/store.py +++ b/sqlspec/adapters/mysqlconnector/adk/store.py @@ -66,12 +66,12 @@ def _mysql_sessions_ddl(session_table: str, owner_id_column_ddl: "str | None") - def _mysql_events_ddl(events_table: str, session_table: str) -> str: - """Generate shared MySQL events CREATE TABLE DDL (post clean-break, 5 columns).""" + """Generate shared MySQL events CREATE TABLE DDL.""" return f""" CREATE TABLE IF NOT EXISTS {events_table} ( + id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(128) NOT NULL, + invocation_id VARCHAR(256), timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), event_data JSON NOT NULL, FOREIGN KEY (session_id) REFERENCES {session_table}(id) ON DELETE CASCADE, @@ -192,15 +192,19 @@ async def create_session( await cursor.close() await conn.commit() - return await self.get_session(session_id) # type: ignore[return-value] + result = await self.get_session(app_name, user_id, session_id) + if result is None: + msg = "Failed to fetch created session" + raise RuntimeError(msg) + return result async def get_session( - self, session_id: str, *, renew_for: "int | timedelta | None" = None + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None ) -> "SessionRecord | None": sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s """ try: @@ -208,11 +212,11 @@ async def get_session( cursor = await conn.cursor() try: if renew_for is not None and self._calculate_expires_at(renew_for) is not None: - update_sql = f"UPDATE {self._session_table} SET update_time = UTC_TIMESTAMP(6) WHERE id = %s" - await cursor.execute(update_sql, (session_id,)) + update_sql = f"UPDATE {self._session_table} SET update_time = UTC_TIMESTAMP(6) WHERE app_name = %s AND user_id = %s AND id = %s" + await cursor.execute(update_sql, (app_name, user_id, session_id)) await conn.commit() - await cursor.execute(sql, (session_id,)) + await cursor.execute(sql, (app_name, user_id, session_id)) row = await cursor.fetchone() finally: await cursor.close() @@ -235,30 +239,30 @@ async def get_session( return None raise - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + async def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: state_json = to_json(state) sql = f""" UPDATE {self._session_table} - SET state = %s - WHERE id = %s + SET state = %s, update_time = UTC_TIMESTAMP(6) + WHERE app_name = %s AND user_id = %s AND id = %s """ async with self._config.provide_connection() as conn: cursor = await conn.cursor() try: - await cursor.execute(sql, (state_json, session_id)) + await cursor.execute(sql, (state_json, app_name, user_id, session_id)) finally: await cursor.close() await conn.commit() - async def delete_session(self, session_id: str) -> None: - sql = f"DELETE FROM {self._session_table} WHERE id = %s" + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + sql = f"DELETE FROM {self._session_table} WHERE app_name = %s AND user_id = %s AND id = %s" async with self._config.provide_connection() as conn: cursor = await conn.cursor() try: - await cursor.execute(sql, (session_id,)) + await cursor.execute(sql, (app_name, user_id, session_id)) finally: await cursor.close() await conn.commit() @@ -307,18 +311,12 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis raise async def append_event(self, event_record: EventRecord) -> None: - """Append an event to a session. - - Args: - event_record: Event record with 5 keys (session_id, invocation_id, - author, timestamp, event_data). - """ event_data = event_record["event_data"] event_data_str = to_json(event_data) if not isinstance(event_data, str) else event_data sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_data + id, session_id, invocation_id, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """ @@ -328,9 +326,9 @@ async def append_event(self, event_record: EventRecord) -> None: await cursor.execute( sql, ( + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], event_data_str, ), @@ -342,40 +340,35 @@ async def append_event(self, event_record: EventRecord) -> None: async def append_event_and_update_state( self, event_record: EventRecord, + app_name: str, + user_id: str, session_id: str, state: "dict[str, Any]", *, - app_name: "str | None" = None, - user_id: "str | None" = None, app_state: "dict[str, Any] | None" = None, user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: - """Atomically append an event and update session + scoped state. - - MySQL doesn't support UPDATE...RETURNING; the UPDATE is followed by a - SELECT inside the same transaction so callers get the refreshed row - without acquiring a second connection. - """ + """Atomically append an event and update session + scoped state.""" event_data = event_record["event_data"] event_data_str = to_json(event_data) if not isinstance(event_data, str) else event_data state_json = to_json(state) insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_data + id, session_id, invocation_id, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """ update_sql = f""" UPDATE {self._session_table} - SET state = %s - WHERE id = %s + SET state = %s, update_time = UTC_TIMESTAMP(6) + WHERE app_name = %s AND user_id = %s AND id = %s """ select_sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s """ app_upsert_sql = f""" @@ -396,25 +389,19 @@ async def append_event_and_update_state( await cursor.execute( insert_sql, ( + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], event_data_str, ), ) - await cursor.execute(update_sql, (state_json, session_id)) - await cursor.execute(select_sql, (session_id,)) + await cursor.execute(update_sql, (state_json, app_name, user_id, session_id)) + await cursor.execute(select_sql, (app_name, user_id, session_id)) row = await cursor.fetchone() if app_state: - if app_name is None: - msg = "app_name is required when app_state is provided." - raise ValueError(msg) await cursor.execute(app_upsert_sql, (app_name, to_json(app_state))) if user_state: - if app_name is None or user_id is None: - msg = "app_name and user_id are required when user_state is provided." - raise ValueError(msg) await cursor.execute(user_upsert_sql, (app_name, user_id, to_json(user_state))) finally: await cursor.close() @@ -435,33 +422,30 @@ async def append_event_and_update_state( ) async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": - """Get events for a session. - - Args: - session_id: Session identifier. - after_timestamp: Only return events after this time. - limit: Maximum number of events to return. - - Returns: - List of event records ordered by timestamp ASC. - """ - where_clauses = ["session_id = %s"] - params: list[Any] = [session_id] + """Get events for a session.""" + where_clauses = ["s.app_name = %s", "s.user_id = %s", "e.session_id = %s"] + params: list[Any] = [app_name, user_id, session_id] if after_timestamp is not None: - where_clauses.append("timestamp > %s") + where_clauses.append("e.timestamp > %s") params.append(after_timestamp) where_clause = " AND ".join(where_clauses) limit_clause = f" LIMIT {limit}" if limit else "" sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_data - FROM {self._events_table} + SELECT e.id, e.session_id, e.invocation_id, e.timestamp, e.event_data, s.app_name, s.user_id + FROM {self._events_table} e + JOIN {self._session_table} s ON e.session_id = s.id WHERE {where_clause} - ORDER BY timestamp ASC{limit_clause} + ORDER BY e.timestamp ASC{limit_clause} """ try: @@ -475,11 +459,13 @@ async def get_events( return [ EventRecord( - session_id=cast("str", row[0]), - invocation_id=cast("str", row[1]), - author=cast("str", row[2]), + id=cast("str", row[0]), + session_id=cast("str", row[1]), + invocation_id=cast("str", row[2]), timestamp=cast("datetime", row[3]), event_data=from_json(row[4]) if isinstance(row[4], str) else cast("dict[str, Any]", row[4]), + app_name=cast("str", row[5]), + user_id=cast("str", row[6]), ) for row in rows ] @@ -689,18 +675,18 @@ async def create_session( return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) async def get_session( - self, session_id: str, *, renew_for: "int | timedelta | None" = None + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None ) -> "SessionRecord | None": """Get session by ID.""" - return await async_(self._get_session)(session_id, renew_for) + return await async_(self._get_session)(app_name, user_id, session_id, renew_for) - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + async def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: """Update session state.""" - await async_(self._update_session_state)(session_id, state) + await async_(self._update_session_state)(app_name, user_id, session_id, state) - async def delete_session(self, session_id: str) -> None: + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: """Delete session and associated events.""" - await async_(self._delete_session)(session_id) + await async_(self._delete_session)(app_name, user_id, session_id) async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": """List sessions for an app.""" @@ -709,30 +695,29 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis async def append_event_and_update_state( self, event_record: EventRecord, + app_name: str, + user_id: str, session_id: str, state: "dict[str, Any]", *, - app_name: "str | None" = None, - user_id: "str | None" = None, app_state: "dict[str, Any] | None" = None, user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: """Atomically append an event and update session + scoped state.""" return await async_(self._append_event_and_update_state)( - event_record, - session_id, - state, - app_name=app_name, - user_id=user_id, - app_state=app_state, - user_state=user_state, + event_record, app_name, user_id, session_id, state, app_state=app_state, user_state=user_state ) async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": """Get events for a session.""" - return await async_(self._get_events)(session_id, after_timestamp, limit) + return await async_(self._get_events)(app_name, user_id, session_id, after_timestamp, limit) async def delete_expired_events(self, before: "datetime") -> int: """Delete events older than the given timestamp.""" @@ -845,17 +830,19 @@ def _create_session( cursor.close() conn.commit() - result = self._get_session(session_id) + result = self._get_session(app_name, user_id, session_id) if result is None: msg = "Failed to fetch created session" raise RuntimeError(msg) return result - def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = None) -> "SessionRecord | None": + def _get_session( + self, app_name: str, user_id: str, session_id: str, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s """ try: @@ -863,11 +850,11 @@ def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = No cursor = conn.cursor() try: if renew_for is not None and self._calculate_expires_at(renew_for) is not None: - update_sql = f"UPDATE {self._session_table} SET update_time = UTC_TIMESTAMP(6) WHERE id = %s" - cursor.execute(update_sql, (session_id,)) + update_sql = f"UPDATE {self._session_table} SET update_time = UTC_TIMESTAMP(6) WHERE app_name = %s AND user_id = %s AND id = %s" + cursor.execute(update_sql, (app_name, user_id, session_id)) conn.commit() - cursor.execute(sql, (session_id,)) + cursor.execute(sql, (app_name, user_id, session_id)) row = cursor.fetchone() finally: cursor.close() @@ -890,30 +877,30 @@ def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = No return None raise - def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + def _update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: state_json = to_json(state) sql = f""" UPDATE {self._session_table} - SET state = %s - WHERE id = %s + SET state = %s, update_time = UTC_TIMESTAMP(6) + WHERE app_name = %s AND user_id = %s AND id = %s """ with self._config.provide_connection() as conn: cursor = conn.cursor() try: - cursor.execute(sql, (state_json, session_id)) + cursor.execute(sql, (state_json, app_name, user_id, session_id)) finally: cursor.close() conn.commit() - def _delete_session(self, session_id: str) -> None: - sql = f"DELETE FROM {self._session_table} WHERE id = %s" + def _delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + sql = f"DELETE FROM {self._session_table} WHERE app_name = %s AND user_id = %s AND id = %s" with self._config.provide_connection() as conn: cursor = conn.cursor() try: - cursor.execute(sql, (session_id,)) + cursor.execute(sql, (app_name, user_id, session_id)) finally: cursor.close() conn.commit() @@ -964,40 +951,35 @@ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Ses def _append_event_and_update_state( self, event_record: EventRecord, + app_name: str, + user_id: str, session_id: str, state: "dict[str, Any]", *, - app_name: "str | None" = None, - user_id: "str | None" = None, app_state: "dict[str, Any] | None" = None, user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: - """Atomically create an event and update session + scoped state. - - MySQL doesn't support UPDATE...RETURNING; the UPDATE is followed by a - SELECT inside the same transaction so callers get the refreshed row - without acquiring a second connection. - """ + """Atomically create an event and update session + scoped state.""" event_data = event_record["event_data"] event_data_str = to_json(event_data) if not isinstance(event_data, str) else event_data state_json = to_json(state) insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_data + id, session_id, invocation_id, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """ update_sql = f""" UPDATE {self._session_table} - SET state = %s - WHERE id = %s + SET state = %s, update_time = UTC_TIMESTAMP(6) + WHERE app_name = %s AND user_id = %s AND id = %s """ select_sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s """ app_upsert_sql = f""" @@ -1018,25 +1000,19 @@ def _append_event_and_update_state( cursor.execute( insert_sql, ( + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], event_data_str, ), ) - cursor.execute(update_sql, (state_json, session_id)) - cursor.execute(select_sql, (session_id,)) + cursor.execute(update_sql, (state_json, app_name, user_id, session_id)) + cursor.execute(select_sql, (app_name, user_id, session_id)) row = cursor.fetchone() if app_state: - if app_name is None: - msg = "app_name is required when app_state is provided." - raise ValueError(msg) cursor.execute(app_upsert_sql, (app_name, to_json(app_state))) if user_state: - if app_name is None or user_id is None: - msg = "app_name and user_id are required when user_state is provided." - raise ValueError(msg) cursor.execute(user_upsert_sql, (app_name, user_id, to_json(user_state))) finally: cursor.close() @@ -1062,7 +1038,7 @@ def _insert_event(self, event_record: EventRecord) -> None: sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_data + id, session_id, invocation_id, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """ @@ -1072,9 +1048,9 @@ def _insert_event(self, event_record: EventRecord) -> None: cursor.execute( sql, ( + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], event_data_str, ), @@ -1084,32 +1060,29 @@ def _insert_event(self, event_record: EventRecord) -> None: conn.commit() def _get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": - """List events for a session ordered by timestamp. - - Args: - session_id: Session identifier. - after_timestamp: Only return events after this time. - limit: Maximum number of events to return. - - Returns: - List of event records ordered by timestamp ASC. - """ - where_clauses = ["session_id = %s"] - params: list[Any] = [session_id] + """List events for a session ordered by timestamp.""" + where_clauses = ["s.app_name = %s", "s.user_id = %s", "e.session_id = %s"] + params: list[Any] = [app_name, user_id, session_id] if after_timestamp is not None: - where_clauses.append("timestamp > %s") + where_clauses.append("e.timestamp > %s") params.append(after_timestamp) where_clause = " AND ".join(where_clauses) limit_clause = " LIMIT %s" if limit else "" sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_data - FROM {self._events_table} + SELECT e.id, e.session_id, e.invocation_id, e.timestamp, e.event_data, s.app_name, s.user_id + FROM {self._events_table} e + JOIN {self._session_table} s ON e.session_id = s.id WHERE {where_clause} - ORDER BY timestamp ASC{limit_clause} + ORDER BY e.timestamp ASC{limit_clause} """ if limit: params.append(limit) @@ -1125,11 +1098,13 @@ def _get_events( return [ EventRecord( - session_id=cast("str", row[0]), - invocation_id=cast("str", row[1]), - author=cast("str", row[2]), + id=cast("str", row[0]), + session_id=cast("str", row[1]), + invocation_id=cast("str", row[2]), timestamp=cast("datetime", row[3]), event_data=from_json(row[4]) if isinstance(row[4], str) else cast("dict[str, Any]", row[4]), + app_name=cast("str", row[5]), + user_id=cast("str", row[6]), ) for row in rows ] diff --git a/sqlspec/adapters/oracledb/adk/store.py b/sqlspec/adapters/oracledb/adk/store.py index 9a20d733c..0548e4e69 100644 --- a/sqlspec/adapters/oracledb/adk/store.py +++ b/sqlspec/adapters/oracledb/adk/store.py @@ -189,14 +189,16 @@ async def create_session( await cursor.execute(sql, params) await conn.commit() - return await self.get_session(session_id) # type: ignore[return-value] + return await self.get_session(app_name, user_id, session_id) # type: ignore[return-value] async def get_session( - self, session_id: str, *, renew_for: "int | timedelta | None" = None + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None ) -> "SessionRecord | None": """Get session by ID. Args: + app_name: Application name. + user_id: User identifier. session_id: Session identifier. renew_for: If positive, touch update_time while reading. @@ -213,8 +215,8 @@ async def get_session( cursor = conn.cursor() if renew_for is not None and self._calculate_expires_at(renew_for) is not None: await cursor.execute( - f"UPDATE {self._session_table} SET update_time = SYSTIMESTAMP WHERE id = :id", - {"id": session_id}, + f"UPDATE {self._session_table} SET update_time = SYSTIMESTAMP WHERE app_name = :app_name AND user_id = :user_id AND id = :id", + {"app_name": app_name, "user_id": user_id, "id": session_id}, ) await conn.commit() @@ -222,9 +224,9 @@ async def get_session( f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} - WHERE id = :id + WHERE app_name = :app_name AND user_id = :user_id AND id = :id """, - {"id": session_id}, + {"app_name": app_name, "user_id": user_id, "id": session_id}, ) row = await cursor.fetchone() @@ -249,10 +251,12 @@ async def get_session( return None raise - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + async def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: """Update session state. Args: + app_name: Application name. + user_id: User identifier. session_id: Session identifier. state: New state dictionary (replaces existing state). @@ -266,28 +270,30 @@ async def update_session_state(self, session_id: str, state: "dict[str, Any]") - sql = f""" UPDATE {self._session_table} SET state = :state, update_time = SYSTIMESTAMP - WHERE id = :id + WHERE app_name = :app_name AND user_id = :user_id AND id = :id """ async with self._config.provide_connection() as conn: cursor = conn.cursor() - await cursor.execute(sql, {"state": state_data, "id": session_id}) + await cursor.execute(sql, {"state": state_data, "app_name": app_name, "user_id": user_id, "id": session_id}) await conn.commit() - async def delete_session(self, session_id: str) -> None: + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: """Delete session and all associated events (cascade). Args: + app_name: Application name. + user_id: User identifier. session_id: Session identifier. Notes: Foreign key constraint ensures events are cascade-deleted. """ - sql = f"DELETE FROM {self._session_table} WHERE id = :id" + sql = f"DELETE FROM {self._session_table} WHERE app_name = :app_name AND user_id = :user_id AND id = :id" async with self._config.provide_connection() as conn: cursor = conn.cursor() - await cursor.execute(sql, {"id": session_id}) + await cursor.execute(sql, {"app_name": app_name, "user_id": user_id, "id": session_id}) await conn.commit() async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": @@ -353,14 +359,13 @@ async def append_event(self, event_record: EventRecord) -> None: """Append an event to a session. Args: - event_record: Event record with 5 keys: session_id, invocation_id, - author, timestamp, event_data. + event_record: Event record. """ sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_data + id, session_id, invocation_id, timestamp, event_data ) VALUES ( - :session_id, :invocation_id, :author, :timestamp, :event_data + :id, :session_id, :invocation_id, :timestamp, :event_data ) """ @@ -369,9 +374,9 @@ async def append_event(self, event_record: EventRecord) -> None: await cursor.execute( sql, { + "id": event_record["id"], "session_id": event_record["session_id"], "invocation_id": event_record["invocation_id"], - "author": event_record["author"], "timestamp": event_record["timestamp"], "event_data": await self._serialize_event_data(event_record["event_data"]), }, @@ -381,26 +386,24 @@ async def append_event(self, event_record: EventRecord) -> None: async def append_event_and_update_state( self, event_record: EventRecord, + app_name: str, + user_id: str, session_id: str, state: "dict[str, Any]", *, - app_name: "str | None" = None, - user_id: "str | None" = None, app_state: "dict[str, Any] | None" = None, user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: """Atomically append an event and update session + scoped state. All writes are executed within a single transaction so they succeed or - fail together. The refreshed SessionRecord is read inside the same - transaction (Oracle's RETURNING INTO requires output bind variables - which complicate async cursor handling, so SELECT-after-UPDATE is used). + fail together. """ insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_data + id, session_id, invocation_id, timestamp, event_data ) VALUES ( - :session_id, :invocation_id, :author, :timestamp, :event_data + :id, :session_id, :invocation_id, :timestamp, :event_data ) """ @@ -408,13 +411,13 @@ async def append_event_and_update_state( update_sql = f""" UPDATE {self._session_table} SET state = :state, update_time = SYSTIMESTAMP - WHERE id = :id + WHERE app_name = :app_name AND user_id = :user_id AND id = :id """ select_sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} - WHERE id = :id + WHERE app_name = :app_name AND user_id = :user_id AND id = :id """ app_upsert_sql = f""" @@ -444,27 +447,23 @@ async def append_event_and_update_state( await cursor.execute( insert_sql, { + "id": event_record["id"], "session_id": event_record["session_id"], "invocation_id": event_record["invocation_id"], - "author": event_record["author"], "timestamp": event_record["timestamp"], "event_data": await self._serialize_event_data(event_record["event_data"]), }, ) - await cursor.execute(update_sql, {"state": state_data, "id": session_id}) - await cursor.execute(select_sql, {"id": session_id}) + await cursor.execute( + update_sql, {"state": state_data, "app_name": app_name, "user_id": user_id, "id": session_id} + ) + await cursor.execute(select_sql, {"app_name": app_name, "user_id": user_id, "id": session_id}) row = await cursor.fetchone() if app_state: - if app_name is None: - msg = "app_name is required when app_state is provided." - raise ValueError(msg) await cursor.execute( app_upsert_sql, {"app_name": app_name, "state": await self._serialize_state(app_state)} ) if user_state: - if app_name is None or user_id is None: - msg = "app_name and user_id are required when user_state is provided." - raise ValueError(msg) await cursor.execute( user_upsert_sql, {"app_name": app_name, "user_id": user_id, "state": await self._serialize_state(user_state)}, @@ -486,11 +485,18 @@ async def append_event_and_update_state( ) async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": """Get events for a session. Args: + app_name: Application name. + user_id: User identifier. session_id: Session identifier. after_timestamp: Only return events after this time. limit: Maximum number of events to return. @@ -499,11 +505,11 @@ async def get_events( List of event records ordered by timestamp ASC. """ - where_clauses = ["session_id = :session_id"] - params: dict[str, Any] = {"session_id": session_id} + where_clauses = ["s.app_name = :app_name", "s.user_id = :user_id", "e.session_id = :session_id"] + params: dict[str, Any] = {"app_name": app_name, "user_id": user_id, "session_id": session_id} if after_timestamp is not None: - where_clauses.append("timestamp > :after_timestamp") + where_clauses.append("e.timestamp > :after_timestamp") params["after_timestamp"] = after_timestamp where_clause = " AND ".join(where_clauses) @@ -512,10 +518,11 @@ async def get_events( limit_clause = f" FETCH FIRST {limit} ROWS ONLY" sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_data - FROM {self._events_table} + SELECT e.id, e.session_id, e.invocation_id, e.timestamp, e.event_data, s.app_name, s.user_id + FROM {self._events_table} e + JOIN {self._session_table} s ON e.session_id = s.id WHERE {where_clause} - ORDER BY timestamp ASC{limit_clause} + ORDER BY e.timestamp ASC{limit_clause} """ try: @@ -526,11 +533,13 @@ async def get_events( return [ EventRecord( - session_id=row[0], - invocation_id=_oracle_text_value(row[1]), - author=_oracle_text_value(row[2]), + id=row[0], + session_id=row[1], + invocation_id=_oracle_text_value(row[2]), timestamp=row[3], event_data=await self._deserialize_json_field(row[4]) or {}, + app_name=row[5], + user_id=row[6], ) for row in rows ] @@ -929,19 +938,15 @@ def _get_create_events_table_sql_for_type(self, storage_type: JSONStorageType) - """ event_data_col = _event_data_column_ddl(storage_type) table_clauses = _oracle_table_feature_clauses( - self._config, - "events", - in_memory=self._in_memory, - hash_partition_key="session_id", - range_partition_key="timestamp", + self._config, "events", in_memory=self._in_memory, hash_partition_key="id", range_partition_key="timestamp" ) return f""" BEGIN EXECUTE IMMEDIATE 'CREATE TABLE {self._events_table} ( + id VARCHAR2(128) PRIMARY KEY, session_id VARCHAR2(128) NOT NULL, invocation_id VARCHAR2(256), - author VARCHAR2(256), timestamp TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP NOT NULL, {event_data_col}, CONSTRAINT fk_{self._events_table}_session FOREIGN KEY (session_id) @@ -1166,18 +1171,18 @@ async def create_session( return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) async def get_session( - self, session_id: str, *, renew_for: "int | timedelta | None" = None + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None ) -> "SessionRecord | None": """Get session by ID.""" - return await async_(self._get_session)(session_id, renew_for) + return await async_(self._get_session)(app_name, user_id, session_id, renew_for) - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + async def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: """Update session state.""" - await async_(self._update_session_state)(session_id, state) + await async_(self._update_session_state)(app_name, user_id, session_id, state) - async def delete_session(self, session_id: str) -> None: + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: """Delete session and associated events.""" - await async_(self._delete_session)(session_id) + await async_(self._delete_session)(app_name, user_id, session_id) async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": """List sessions for an app.""" @@ -1186,30 +1191,29 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis async def append_event_and_update_state( self, event_record: EventRecord, + app_name: str, + user_id: str, session_id: str, state: "dict[str, Any]", *, - app_name: "str | None" = None, - user_id: "str | None" = None, app_state: "dict[str, Any] | None" = None, user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: """Atomically append an event and update session + scoped state.""" return await async_(self._append_event_and_update_state)( - event_record, - session_id, - state, - app_name=app_name, - user_id=user_id, - app_state=app_state, - user_state=user_state, + event_record, app_name, user_id, session_id, state, app_state=app_state, user_state=user_state ) async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": """Get events for a session.""" - return await async_(self._get_events)(session_id, after_timestamp, limit) + return await async_(self._get_events)(app_name, user_id, session_id, after_timestamp, limit) async def delete_expired_events(self, before: "datetime") -> int: """Delete events older than the given timestamp.""" @@ -1739,16 +1743,20 @@ def _create_session( cursor.execute(sql, params) conn.commit() - result = self._get_session(session_id) + result = self._get_session(app_name, user_id, session_id) if result is None: msg = "Failed to fetch created session" raise RuntimeError(msg) return result - def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = None) -> "SessionRecord | None": + def _get_session( + self, app_name: str, user_id: str, session_id: str, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": """Get session by ID. Args: + app_name: Application name. + user_id: User identifier. session_id: Session identifier. renew_for: If positive, touch update_time while reading. @@ -1763,7 +1771,7 @@ def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = No sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} - WHERE id = :id + WHERE app_name = :app_name AND user_id = :user_id AND id = :id """ try: @@ -1771,12 +1779,12 @@ def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = No cursor = conn.cursor() if renew_for is not None and self._calculate_expires_at(renew_for) is not None: cursor.execute( - f"UPDATE {self._session_table} SET update_time = SYSTIMESTAMP WHERE id = :id", - {"id": session_id}, + f"UPDATE {self._session_table} SET update_time = SYSTIMESTAMP WHERE app_name = :app_name AND user_id = :user_id AND id = :id", + {"app_name": app_name, "user_id": user_id, "id": session_id}, ) conn.commit() - cursor.execute(sql, {"id": session_id}) + cursor.execute(sql, {"app_name": app_name, "user_id": user_id, "id": session_id}) row = cursor.fetchone() if row is None: @@ -1800,10 +1808,12 @@ def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = No return None raise - def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + def _update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: """Update session state. Args: + app_name: Application name. + user_id: User identifier. session_id: Session identifier. state: New state dictionary (replaces existing state). @@ -1817,28 +1827,30 @@ def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> Non sql = f""" UPDATE {self._session_table} SET state = :state, update_time = SYSTIMESTAMP - WHERE id = :id + WHERE app_name = :app_name AND user_id = :user_id AND id = :id """ with self._config.provide_connection() as conn: cursor = conn.cursor() - cursor.execute(sql, {"state": state_data, "id": session_id}) + cursor.execute(sql, {"state": state_data, "app_name": app_name, "user_id": user_id, "id": session_id}) conn.commit() - def _delete_session(self, session_id: str) -> None: + def _delete_session(self, app_name: str, user_id: str, session_id: str) -> None: """Delete session and all associated events (cascade). Args: + app_name: Application name. + user_id: User identifier. session_id: Session identifier. Notes: Foreign key constraint ensures events are cascade-deleted. """ - sql = f"DELETE FROM {self._session_table} WHERE id = :id" + sql = f"DELETE FROM {self._session_table} WHERE app_name = :app_name AND user_id = :user_id AND id = :id" with self._config.provide_connection() as conn: cursor = conn.cursor() - cursor.execute(sql, {"id": session_id}) + cursor.execute(sql, {"app_name": app_name, "user_id": user_id, "id": session_id}) conn.commit() def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": @@ -1903,25 +1915,20 @@ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Ses def _append_event_and_update_state( self, event_record: EventRecord, + app_name: str, + user_id: str, session_id: str, state: "dict[str, Any]", *, - app_name: "str | None" = None, - user_id: "str | None" = None, app_state: "dict[str, Any] | None" = None, user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: - """Atomically create an event and update session + scoped state. - - All writes are executed within a single transaction so they succeed or - fail together; the refreshed SessionRecord is read inside the same - transaction. - """ + """Atomically create an event and update session + scoped state.""" insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_data + id, session_id, invocation_id, timestamp, event_data ) VALUES ( - :session_id, :invocation_id, :author, :timestamp, :event_data + :id, :session_id, :invocation_id, :timestamp, :event_data ) """ @@ -1929,13 +1936,13 @@ def _append_event_and_update_state( update_sql = f""" UPDATE {self._session_table} SET state = :state, update_time = SYSTIMESTAMP - WHERE id = :id + WHERE app_name = :app_name AND user_id = :user_id AND id = :id """ select_sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} - WHERE id = :id + WHERE app_name = :app_name AND user_id = :user_id AND id = :id """ app_upsert_sql = f""" @@ -1965,25 +1972,21 @@ def _append_event_and_update_state( cursor.execute( insert_sql, { + "id": event_record["id"], "session_id": event_record["session_id"], "invocation_id": event_record["invocation_id"], - "author": event_record["author"], "timestamp": event_record["timestamp"], "event_data": self._serialize_event_data(event_record["event_data"]), }, ) - cursor.execute(update_sql, {"state": state_data, "id": session_id}) - cursor.execute(select_sql, {"id": session_id}) + cursor.execute( + update_sql, {"state": state_data, "app_name": app_name, "user_id": user_id, "id": session_id} + ) + cursor.execute(select_sql, {"app_name": app_name, "user_id": user_id, "id": session_id}) row = cursor.fetchone() if app_state: - if app_name is None: - msg = "app_name is required when app_state is provided." - raise ValueError(msg) cursor.execute(app_upsert_sql, {"app_name": app_name, "state": self._serialize_state(app_state)}) if user_state: - if app_name is None or user_id is None: - msg = "app_name and user_id are required when user_state is provided." - raise ValueError(msg) cursor.execute( user_upsert_sql, {"app_name": app_name, "user_id": user_id, "state": self._serialize_state(user_state)}, @@ -2005,11 +2008,18 @@ def _append_event_and_update_state( ) def _get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": """List events for a session ordered by timestamp. Args: + app_name: Application name. + user_id: User identifier. session_id: Session identifier. after_timestamp: Only return events after this time. limit: Maximum number of events to return. @@ -2018,20 +2028,21 @@ def _get_events( List of event records ordered by timestamp ASC. """ - where_clauses = ["session_id = :session_id"] - params: dict[str, Any] = {"session_id": session_id} + where_clauses = ["s.app_name = :app_name", "s.user_id = :user_id", "e.session_id = :session_id"] + params: dict[str, Any] = {"app_name": app_name, "user_id": user_id, "session_id": session_id} if after_timestamp is not None: - where_clauses.append("timestamp > :after_timestamp") + where_clauses.append("e.timestamp > :after_timestamp") params["after_timestamp"] = after_timestamp where_clause = " AND ".join(where_clauses) limit_clause = f" FETCH FIRST {limit} ROWS ONLY" if limit else "" sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_data - FROM {self._events_table} + SELECT e.id, e.session_id, e.invocation_id, e.timestamp, e.event_data, s.app_name, s.user_id + FROM {self._events_table} e + JOIN {self._session_table} s ON e.session_id = s.id WHERE {where_clause} - ORDER BY timestamp ASC{limit_clause} + ORDER BY e.timestamp ASC{limit_clause} """ try: @@ -2042,11 +2053,13 @@ def _get_events( return [ EventRecord( - session_id=row[0], - invocation_id=_oracle_text_value(row[1]), - author=_oracle_text_value(row[2]), + id=row[0], + session_id=row[1], + invocation_id=_oracle_text_value(row[2]), timestamp=row[3], event_data=self._deserialize_json_field(row[4]) or {}, + app_name=row[5], + user_id=row[6], ) for row in rows ] @@ -2090,9 +2103,9 @@ def _append_event(self, event_record: EventRecord) -> None: """Synchronous implementation of append_event.""" sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_data + id, session_id, invocation_id, timestamp, event_data ) VALUES ( - :session_id, :invocation_id, :author, :timestamp, :event_data + :id, :session_id, :invocation_id, :timestamp, :event_data ) """ @@ -2101,9 +2114,9 @@ def _append_event(self, event_record: EventRecord) -> None: cursor.execute( sql, { + "id": event_record["id"], "session_id": event_record["session_id"], "invocation_id": event_record["invocation_id"], - "author": event_record["author"], "timestamp": event_record["timestamp"], "event_data": self._serialize_event_data(event_record["event_data"]), }, diff --git a/sqlspec/adapters/psqlpy/adk/store.py b/sqlspec/adapters/psqlpy/adk/store.py index 25dac6d71..91658f7f2 100644 --- a/sqlspec/adapters/psqlpy/adk/store.py +++ b/sqlspec/adapters/psqlpy/adk/store.py @@ -78,28 +78,32 @@ async def create_session( """ await conn.execute(sql, [session_id, app_name, user_id, state]) - return await self.get_session(session_id) # type: ignore[return-value] + res = await self.get_session(app_name, user_id, session_id) + if res is None: + msg = "Failed to retrieve created session." + raise RuntimeError(msg) + return res async def get_session( - self, session_id: str, *, renew_for: "int | timedelta | None" = None + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None ) -> "SessionRecord | None": if renew_for is not None and self._calculate_expires_at(renew_for) is not None: sql = f""" UPDATE {self._session_table} SET update_time = CURRENT_TIMESTAMP - WHERE id = $1 + WHERE app_name = $1 AND user_id = $2 AND id = $3 RETURNING id, app_name, user_id, state, create_time, update_time """ else: sql = f""" - SELECT id, app_name, user_id, state, create_time, update_time - FROM {self._session_table} - WHERE id = $1 - """ + SELECT id, app_name, user_id, state, create_time, update_time + FROM {self._session_table} + WHERE app_name = $1 AND user_id = $2 AND id = $3 + """ try: async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] - result = await conn.fetch(sql, [session_id]) + result = await conn.fetch(sql, [app_name, user_id, session_id]) rows: list[dict[str, Any]] = result.result() if result else [] if not rows: @@ -120,21 +124,21 @@ async def get_session( return None raise - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + async def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: sql = f""" UPDATE {self._session_table} SET state = $1, update_time = CURRENT_TIMESTAMP - WHERE id = $2 + WHERE app_name = $2 AND user_id = $3 AND id = $4 """ async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] - await conn.execute(sql, [state, session_id]) + await conn.execute(sql, [state, app_name, user_id, session_id]) - async def delete_session(self, session_id: str) -> None: - sql = f"DELETE FROM {self._session_table} WHERE id = $1" + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + sql = f"DELETE FROM {self._session_table} WHERE app_name = $1 AND user_id = $2 AND id = $3" async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] - await conn.execute(sql, [session_id]) + await conn.execute(sql, [app_name, user_id, session_id]) async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": if user_id is None: @@ -179,7 +183,7 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis async def append_event(self, event_record: EventRecord) -> None: sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_data + id, session_id, invocation_id, timestamp, event_data ) VALUES ($1, $2, $3, $4, $5) """ @@ -187,9 +191,9 @@ async def append_event(self, event_record: EventRecord) -> None: await conn.execute( sql, [ + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], event_record["event_data"], ], @@ -198,23 +202,23 @@ async def append_event(self, event_record: EventRecord) -> None: async def append_event_and_update_state( self, event_record: EventRecord, + app_name: str, + user_id: str, session_id: str, state: "dict[str, Any]", *, - app_name: "str | None" = None, - user_id: "str | None" = None, app_state: "dict[str, Any] | None" = None, user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_data + id, session_id, invocation_id, timestamp, event_data ) VALUES ($1, $2, $3, $4, $5) """ update_sql = f""" UPDATE {self._session_table} SET state = $1, update_time = CURRENT_TIMESTAMP - WHERE id = $2 + WHERE app_name = $2 AND user_id = $3 AND id = $4 RETURNING id, app_name, user_id, state, create_time, update_time """ app_upsert_sql = f""" @@ -236,24 +240,18 @@ async def append_event_and_update_state( await conn.execute( insert_sql, [ + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], event_record["event_data"], ], ) - result = await conn.fetch(update_sql, [state, session_id]) + result = await conn.fetch(update_sql, [state, app_name, user_id, session_id]) rows: list[dict[str, Any]] = result.result() if result else [] if app_state: - if app_name is None: - msg = "app_name is required when app_state is provided." - raise ValueError(msg) await conn.execute(app_upsert_sql, [app_name, app_state]) if user_state: - if app_name is None or user_id is None: - msg = "app_name and user_id are required when user_state is provided." - raise ValueError(msg) await conn.execute(user_upsert_sql, [app_name, user_id, user_state]) if not rows: @@ -271,13 +269,18 @@ async def append_event_and_update_state( ) async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": - where_clauses = ["session_id = $1"] - params: list[Any] = [session_id] + where_clauses = ["s.app_name = $1", "s.user_id = $2", "e.session_id = $3"] + params: list[Any] = [app_name, user_id, session_id] if after_timestamp is not None: - where_clauses.append(f"timestamp > ${len(params) + 1}") + where_clauses.append(f"e.timestamp > ${len(params) + 1}") params.append(after_timestamp) where_clause = " AND ".join(where_clauses) @@ -286,10 +289,11 @@ async def get_events( params.append(limit) sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_data - FROM {self._events_table} + SELECT e.id, e.session_id, e.invocation_id, e.timestamp, e.event_data, s.app_name, s.user_id + FROM {self._events_table} e + JOIN {self._session_table} s ON e.session_id = s.id WHERE {where_clause} - ORDER BY timestamp ASC{limit_clause} + ORDER BY e.timestamp ASC{limit_clause} """ try: @@ -299,11 +303,13 @@ async def get_events( return [ EventRecord( + id=row["id"], session_id=row["session_id"], invocation_id=row["invocation_id"], - author=row["author"], timestamp=row["timestamp"], event_data=row["event_data"], + app_name=row["app_name"], + user_id=row["user_id"], ) for row in rows ] @@ -452,9 +458,9 @@ async def _get_create_sessions_table_sql(self) -> str: async def _get_create_events_table_sql(self) -> str: return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( + id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(256) NOT NULL, + invocation_id VARCHAR(256), timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, event_data JSONB NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE diff --git a/sqlspec/adapters/psycopg/adk/store.py b/sqlspec/adapters/psycopg/adk/store.py index 9d09f6168..15def115f 100644 --- a/sqlspec/adapters/psycopg/adk/store.py +++ b/sqlspec/adapters/psycopg/adk/store.py @@ -113,28 +113,30 @@ async def create_session( async with self._config.provide_connection() as conn, conn.cursor() as cur: await cur.execute(query, params) - return await self.get_session(session_id) # type: ignore[return-value] + return await self.get_session(app_name, user_id, session_id) # type: ignore[return-value] async def get_session( - self, session_id: str, *, renew_for: "int | timedelta | None" = None + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None ) -> "SessionRecord | None": if renew_for is not None and self._calculate_expires_at(renew_for) is not None: query = pg_sql.SQL(""" UPDATE {table} SET update_time = CURRENT_TIMESTAMP - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s RETURNING id, app_name, user_id, state, create_time, update_time """).format(table=pg_sql.Identifier(self._session_table)) + params = (app_name, user_id, session_id) else: query = pg_sql.SQL(""" SELECT id, app_name, user_id, state, create_time, update_time FROM {table} - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s """).format(table=pg_sql.Identifier(self._session_table)) + params = (app_name, user_id, session_id) try: async with self._config.provide_connection() as conn, conn.cursor() as cur: - await cur.execute(query, (session_id,)) + await cur.execute(query, params) row = await cur.fetchone() if row is None: @@ -151,21 +153,23 @@ async def get_session( except errors.UndefinedTable: return None - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + async def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: query = pg_sql.SQL(""" UPDATE {table} SET state = %s, update_time = CURRENT_TIMESTAMP - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s """).format(table=pg_sql.Identifier(self._session_table)) async with self._config.provide_connection() as conn, conn.cursor() as cur: - await cur.execute(query, (Jsonb(state), session_id)) + await cur.execute(query, (Jsonb(state), app_name, user_id, session_id)) - async def delete_session(self, session_id: str) -> None: - query = pg_sql.SQL("DELETE FROM {table} WHERE id = %s").format(table=pg_sql.Identifier(self._session_table)) + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + query = pg_sql.SQL("DELETE FROM {table} WHERE app_name = %s AND user_id = %s AND id = %s").format( + table=pg_sql.Identifier(self._session_table) + ) async with self._config.provide_connection() as conn, conn.cursor() as cur: - await cur.execute(query, (session_id,)) + await cur.execute(query, (app_name, user_id, session_id)) async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": if user_id is None: @@ -207,7 +211,7 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis async def append_event(self, event_record: EventRecord) -> None: query = pg_sql.SQL(""" INSERT INTO {table} ( - session_id, invocation_id, author, timestamp, event_data + id, session_id, invocation_id, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """).format(table=pg_sql.Identifier(self._events_table)) @@ -218,9 +222,9 @@ async def append_event(self, event_record: EventRecord) -> None: await cur.execute( query, ( + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], jsonb_value, ), @@ -229,24 +233,24 @@ async def append_event(self, event_record: EventRecord) -> None: async def append_event_and_update_state( self, event_record: EventRecord, + app_name: str, + user_id: str, session_id: str, state: "dict[str, Any]", *, - app_name: "str | None" = None, - user_id: "str | None" = None, app_state: "dict[str, Any] | None" = None, user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: insert_query = pg_sql.SQL(""" INSERT INTO {table} ( - session_id, invocation_id, author, timestamp, event_data + id, session_id, invocation_id, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """).format(table=pg_sql.Identifier(self._events_table)) update_query = pg_sql.SQL(""" UPDATE {table} SET state = %s, update_time = CURRENT_TIMESTAMP - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s RETURNING id, app_name, user_id, state, create_time, update_time """).format(table=pg_sql.Identifier(self._session_table)) @@ -273,24 +277,18 @@ async def append_event_and_update_state( await cur.execute( insert_query, ( + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], jsonb_value, ), ) - await cur.execute(update_query, (Jsonb(state), session_id)) + await cur.execute(update_query, (Jsonb(state), app_name, user_id, session_id)) row = await cur.fetchone() if app_state: - if app_name is None: - msg = "app_name is required when app_state is provided." - raise ValueError(msg) await cur.execute(app_upsert_query, (app_name, Jsonb(app_state))) if user_state: - if app_name is None or user_id is None: - msg = "app_name and user_id are required when user_state is provided." - raise ValueError(msg) await cur.execute(user_upsert_query, (app_name, user_id, Jsonb(user_state))) await conn.commit() @@ -308,30 +306,37 @@ async def append_event_and_update_state( ) async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": - where_clauses = ["session_id = %s"] - params: list[Any] = [session_id] + where_clauses = [pg_sql.SQL("s.app_name = %s"), pg_sql.SQL("s.user_id = %s"), pg_sql.SQL("e.session_id = %s")] + params: list[Any] = [app_name, user_id, session_id] if after_timestamp is not None: - where_clauses.append("timestamp > %s") + where_clauses.append(pg_sql.SQL("e.timestamp > %s")) params.append(after_timestamp) - where_clause = " AND ".join(where_clauses) + where_clause = pg_sql.SQL(" AND ").join(where_clauses) if limit: params.append(limit) query = pg_sql.SQL( """ - SELECT session_id, invocation_id, author, timestamp, event_data - FROM {table} + SELECT e.id, e.session_id, e.invocation_id, e.timestamp, e.event_data, s.app_name, s.user_id + FROM {events_table} e + JOIN {session_table} s ON e.session_id = s.id WHERE {where_clause} - ORDER BY timestamp ASC{limit_clause} + ORDER BY e.timestamp ASC{limit_clause} """ ).format( - table=pg_sql.Identifier(self._events_table), - where_clause=pg_sql.SQL(where_clause), # pyright: ignore[reportArgumentType] - limit_clause=pg_sql.SQL(" LIMIT %s" if limit else ""), # pyright: ignore[reportArgumentType] + events_table=pg_sql.Identifier(self._events_table), + session_table=pg_sql.Identifier(self._session_table), + where_clause=where_clause, + limit_clause=pg_sql.SQL(" LIMIT %s" if limit else ""), ) try: @@ -341,11 +346,13 @@ async def get_events( return [ EventRecord( + id=row["id"], session_id=row["session_id"], invocation_id=row["invocation_id"], - author=row["author"], timestamp=row["timestamp"], event_data=row["event_data"], + app_name=row["app_name"], + user_id=row["user_id"], ) for row in rows ] @@ -480,9 +487,9 @@ async def _get_create_sessions_table_sql(self) -> str: async def _get_create_events_table_sql(self) -> str: return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( + id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(256) NOT NULL, + invocation_id VARCHAR(256), timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, event_data JSONB NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE @@ -583,18 +590,18 @@ async def create_session( return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) async def get_session( - self, session_id: str, *, renew_for: "int | timedelta | None" = None + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None ) -> "SessionRecord | None": """Get session by ID.""" - return await async_(self._get_session)(session_id, renew_for) + return await async_(self._get_session)(app_name, user_id, session_id, renew_for) - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + async def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: """Update session state.""" - await async_(self._update_session_state)(session_id, state) + await async_(self._update_session_state)(app_name, user_id, session_id, state) - async def delete_session(self, session_id: str) -> None: + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: """Delete session and associated events.""" - await async_(self._delete_session)(session_id) + await async_(self._delete_session)(app_name, user_id, session_id) async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": """List sessions for an app.""" @@ -603,30 +610,29 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis async def append_event_and_update_state( self, event_record: EventRecord, + app_name: str, + user_id: str, session_id: str, state: "dict[str, Any]", *, - app_name: "str | None" = None, - user_id: "str | None" = None, app_state: "dict[str, Any] | None" = None, user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: """Atomically append an event and update session + scoped state.""" return await async_(self._append_event_and_update_state)( - event_record, - session_id, - state, - app_name=app_name, - user_id=user_id, - app_state=app_state, - user_state=user_state, + event_record, app_name, user_id, session_id, state, app_state=app_state, user_state=user_state ) async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": """Get events for a session.""" - return await async_(self._get_events)(session_id, after_timestamp, limit) + return await async_(self._get_events)(app_name, user_id, session_id, after_timestamp, limit) async def delete_expired_events(self, before: "datetime") -> int: """Delete events older than the given timestamp.""" @@ -693,9 +699,9 @@ async def _get_create_sessions_table_sql(self) -> str: async def _get_create_events_table_sql(self) -> str: return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( + id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(256) NOT NULL, + invocation_id VARCHAR(256), timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, event_data JSONB NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE @@ -789,30 +795,34 @@ def _create_session( with self._config.provide_connection() as conn, conn.cursor() as cur: cur.execute(query, params) - result = self._get_session(session_id) + result = self._get_session(app_name, user_id, session_id) if result is None: msg = "Failed to fetch created session" raise RuntimeError(msg) return result - def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = None) -> "SessionRecord | None": + def _get_session( + self, app_name: str, user_id: str, session_id: str, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": if renew_for is not None and self._calculate_expires_at(renew_for) is not None: query = pg_sql.SQL(""" UPDATE {table} SET update_time = CURRENT_TIMESTAMP - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s RETURNING id, app_name, user_id, state, create_time, update_time """).format(table=pg_sql.Identifier(self._session_table)) + params = (app_name, user_id, session_id) else: query = pg_sql.SQL(""" SELECT id, app_name, user_id, state, create_time, update_time FROM {table} - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s """).format(table=pg_sql.Identifier(self._session_table)) + params = (app_name, user_id, session_id) try: with self._config.provide_connection() as conn, conn.cursor() as cur: - cur.execute(query, (session_id,)) + cur.execute(query, params) row = cur.fetchone() if row is None: @@ -829,21 +839,23 @@ def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = No except errors.UndefinedTable: return None - def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + def _update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: query = pg_sql.SQL(""" UPDATE {table} SET state = %s, update_time = CURRENT_TIMESTAMP - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s """).format(table=pg_sql.Identifier(self._session_table)) with self._config.provide_connection() as conn, conn.cursor() as cur: - cur.execute(query, (Jsonb(state), session_id)) + cur.execute(query, (Jsonb(state), app_name, user_id, session_id)) - def _delete_session(self, session_id: str) -> None: - query = pg_sql.SQL("DELETE FROM {table} WHERE id = %s").format(table=pg_sql.Identifier(self._session_table)) + def _delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + query = pg_sql.SQL("DELETE FROM {table} WHERE app_name = %s AND user_id = %s AND id = %s").format( + table=pg_sql.Identifier(self._session_table) + ) with self._config.provide_connection() as conn, conn.cursor() as cur: - cur.execute(query, (session_id,)) + cur.execute(query, (app_name, user_id, session_id)) def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": if user_id is None: @@ -885,7 +897,7 @@ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Ses def _insert_event(self, event_record: EventRecord) -> None: insert_query = pg_sql.SQL(""" INSERT INTO {table} ( - session_id, invocation_id, author, timestamp, event_data + id, session_id, invocation_id, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """).format(table=pg_sql.Identifier(self._events_table)) @@ -896,9 +908,9 @@ def _insert_event(self, event_record: EventRecord) -> None: cur.execute( insert_query, ( + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], jsonb_value, ), @@ -908,24 +920,24 @@ def _insert_event(self, event_record: EventRecord) -> None: def _append_event_and_update_state( self, event_record: EventRecord, + app_name: str, + user_id: str, session_id: str, state: "dict[str, Any]", *, - app_name: "str | None" = None, - user_id: "str | None" = None, app_state: "dict[str, Any] | None" = None, user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: insert_query = pg_sql.SQL(""" INSERT INTO {table} ( - session_id, invocation_id, author, timestamp, event_data + id, session_id, invocation_id, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """).format(table=pg_sql.Identifier(self._events_table)) update_query = pg_sql.SQL(""" UPDATE {table} SET state = %s, update_time = CURRENT_TIMESTAMP - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s RETURNING id, app_name, user_id, state, create_time, update_time """).format(table=pg_sql.Identifier(self._session_table)) @@ -952,24 +964,18 @@ def _append_event_and_update_state( cur.execute( insert_query, ( + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], jsonb_value, ), ) - cur.execute(update_query, (Jsonb(state), session_id)) + cur.execute(update_query, (Jsonb(state), app_name, user_id, session_id)) row = cur.fetchone() if app_state: - if app_name is None: - msg = "app_name is required when app_state is provided." - raise ValueError(msg) cur.execute(app_upsert_query, (app_name, Jsonb(app_state))) if user_state: - if app_name is None or user_id is None: - msg = "app_name and user_id are required when user_state is provided." - raise ValueError(msg) cur.execute(user_upsert_query, (app_name, user_id, Jsonb(user_state))) conn.commit() @@ -987,30 +993,37 @@ def _append_event_and_update_state( ) def _get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": - where_clauses = ["session_id = %s"] - params: list[Any] = [session_id] + where_clauses = [pg_sql.SQL("s.app_name = %s"), pg_sql.SQL("s.user_id = %s"), pg_sql.SQL("e.session_id = %s")] + params: list[Any] = [app_name, user_id, session_id] if after_timestamp is not None: - where_clauses.append("timestamp > %s") + where_clauses.append(pg_sql.SQL("e.timestamp > %s")) params.append(after_timestamp) - where_clause = " AND ".join(where_clauses) + where_clause = pg_sql.SQL(" AND ").join(where_clauses) if limit: params.append(limit) query = pg_sql.SQL( """ - SELECT session_id, invocation_id, author, timestamp, event_data - FROM {table} + SELECT e.id, e.session_id, e.invocation_id, e.timestamp, e.event_data, s.app_name, s.user_id + FROM {events_table} e + JOIN {session_table} s ON e.session_id = s.id WHERE {where_clause} - ORDER BY timestamp ASC{limit_clause} + ORDER BY e.timestamp ASC{limit_clause} """ ).format( - table=pg_sql.Identifier(self._events_table), - where_clause=pg_sql.SQL(where_clause), # pyright: ignore[reportArgumentType] - limit_clause=pg_sql.SQL(" LIMIT %s" if limit else ""), # pyright: ignore[reportArgumentType] + events_table=pg_sql.Identifier(self._events_table), + session_table=pg_sql.Identifier(self._session_table), + where_clause=where_clause, + limit_clause=pg_sql.SQL(" LIMIT %s" if limit else ""), ) try: @@ -1020,11 +1033,13 @@ def _get_events( return [ EventRecord( + id=row["id"], session_id=row["session_id"], invocation_id=row["invocation_id"], - author=row["author"], timestamp=row["timestamp"], event_data=row["event_data"], + app_name=row["app_name"], + user_id=row["user_id"], ) for row in rows ] diff --git a/sqlspec/adapters/pymysql/adk/store.py b/sqlspec/adapters/pymysql/adk/store.py index 354d159f6..1a1915ce7 100644 --- a/sqlspec/adapters/pymysql/adk/store.py +++ b/sqlspec/adapters/pymysql/adk/store.py @@ -68,18 +68,18 @@ async def create_session( return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) async def get_session( - self, session_id: str, *, renew_for: "int | timedelta | None" = None + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None ) -> "SessionRecord | None": """Get session by ID.""" - return await async_(self._get_session)(session_id, renew_for) + return await async_(self._get_session)(app_name, user_id, session_id, renew_for) - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + async def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: """Update session state.""" - await async_(self._update_session_state)(session_id, state) + await async_(self._update_session_state)(app_name, user_id, session_id, state) - async def delete_session(self, session_id: str) -> None: + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: """Delete session and associated events.""" - await async_(self._delete_session)(session_id) + await async_(self._delete_session)(app_name, user_id, session_id) async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": """List sessions for an app.""" @@ -88,30 +88,29 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis async def append_event_and_update_state( self, event_record: EventRecord, + app_name: str, + user_id: str, session_id: str, state: "dict[str, Any]", *, - app_name: "str | None" = None, - user_id: "str | None" = None, app_state: "dict[str, Any] | None" = None, user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: """Atomically append an event and update session + scoped state.""" return await async_(self._append_event_and_update_state)( - event_record, - session_id, - state, - app_name=app_name, - user_id=user_id, - app_state=app_state, - user_state=user_state, + event_record, app_name, user_id, session_id, state, app_state=app_state, user_state=user_state ) async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": """Get events for a session.""" - return await async_(self._get_events)(session_id, after_timestamp, limit) + return await async_(self._get_events)(app_name, user_id, session_id, after_timestamp, limit) async def delete_expired_events(self, before: "datetime") -> int: """Delete events older than the given timestamp.""" @@ -183,9 +182,9 @@ async def _get_create_events_table_sql(self) -> str: """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( + id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(128) NOT NULL, + invocation_id VARCHAR(256), timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), event_data JSON NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE, @@ -288,17 +287,19 @@ def _create_session( cursor.close() conn.commit() - result = self._get_session(session_id) + result = self._get_session(app_name, user_id, session_id) if result is None: msg = "Failed to fetch created session" raise RuntimeError(msg) return result - def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = None) -> "SessionRecord | None": + def _get_session( + self, app_name: str, user_id: str, session_id: str, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s """ try: @@ -306,11 +307,11 @@ def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = No cursor = conn.cursor() try: if renew_for is not None and self._calculate_expires_at(renew_for) is not None: - update_sql = f"UPDATE {self._session_table} SET update_time = UTC_TIMESTAMP(6) WHERE id = %s" - cursor.execute(update_sql, (session_id,)) + update_sql = f"UPDATE {self._session_table} SET update_time = UTC_TIMESTAMP(6) WHERE app_name = %s AND user_id = %s AND id = %s" + cursor.execute(update_sql, (app_name, user_id, session_id)) conn.commit() - cursor.execute(sql, (session_id,)) + cursor.execute(sql, (app_name, user_id, session_id)) row = cursor.fetchone() finally: cursor.close() @@ -318,12 +319,12 @@ def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = No if row is None: return None - session_id_val, app_name, user_id, state_json, create_time, update_time = row + session_id_val, app_name_val, user_id_val, state_json, create_time, update_time = row return SessionRecord( id=session_id_val, - app_name=app_name, - user_id=user_id, + app_name=app_name_val, + user_id=user_id_val, state=from_json(state_json) if isinstance(state_json, str) else state_json, create_time=create_time, update_time=update_time, @@ -333,30 +334,30 @@ def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = No return None raise - def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + def _update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: state_json = to_json(state) sql = f""" UPDATE {self._session_table} - SET state = %s - WHERE id = %s + SET state = %s, update_time = UTC_TIMESTAMP(6) + WHERE app_name = %s AND user_id = %s AND id = %s """ with self._config.provide_connection() as conn: cursor = conn.cursor() try: - cursor.execute(sql, (state_json, session_id)) + cursor.execute(sql, (state_json, app_name, user_id, session_id)) finally: cursor.close() conn.commit() - def _delete_session(self, session_id: str) -> None: - sql = f"DELETE FROM {self._session_table} WHERE id = %s" + def _delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + sql = f"DELETE FROM {self._session_table} WHERE app_name = %s AND user_id = %s AND id = %s" with self._config.provide_connection() as conn: cursor = conn.cursor() try: - cursor.execute(sql, (session_id,)) + cursor.execute(sql, (app_name, user_id, session_id)) finally: cursor.close() conn.commit() @@ -407,40 +408,35 @@ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Ses def _append_event_and_update_state( self, event_record: EventRecord, + app_name: str, + user_id: str, session_id: str, state: "dict[str, Any]", *, - app_name: "str | None" = None, - user_id: "str | None" = None, app_state: "dict[str, Any] | None" = None, user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: - """Atomically create an event and update session + scoped state. - - MySQL doesn't support UPDATE...RETURNING; the UPDATE is followed by a - SELECT inside the same transaction so callers get the refreshed row - without acquiring a second connection. - """ + """Atomically create an event and update session + scoped state.""" event_data = event_record["event_data"] event_data_str = to_json(event_data) if not isinstance(event_data, str) else event_data state_json = to_json(state) insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_data + id, session_id, invocation_id, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """ update_sql = f""" UPDATE {self._session_table} - SET state = %s - WHERE id = %s + SET state = %s, update_time = UTC_TIMESTAMP(6) + WHERE app_name = %s AND user_id = %s AND id = %s """ select_sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s """ app_upsert_sql = f""" @@ -461,25 +457,19 @@ def _append_event_and_update_state( cursor.execute( insert_sql, ( + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], event_data_str, ), ) - cursor.execute(update_sql, (state_json, session_id)) - cursor.execute(select_sql, (session_id,)) + cursor.execute(update_sql, (state_json, app_name, user_id, session_id)) + cursor.execute(select_sql, (app_name, user_id, session_id)) row = cursor.fetchone() if app_state: - if app_name is None: - msg = "app_name is required when app_state is provided." - raise ValueError(msg) cursor.execute(app_upsert_sql, (app_name, to_json(app_state))) if user_state: - if app_name is None or user_id is None: - msg = "app_name and user_id are required when user_state is provided." - raise ValueError(msg) cursor.execute(user_upsert_sql, (app_name, user_id, to_json(user_state))) finally: cursor.close() @@ -505,7 +495,7 @@ def _insert_event(self, event_record: EventRecord) -> None: sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_data + id, session_id, invocation_id, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """ @@ -515,9 +505,9 @@ def _insert_event(self, event_record: EventRecord) -> None: cursor.execute( sql, ( + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], event_data_str, ), @@ -527,32 +517,29 @@ def _insert_event(self, event_record: EventRecord) -> None: conn.commit() def _get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": - """List events for a session ordered by timestamp. - - Args: - session_id: Session identifier. - after_timestamp: Only return events after this time. - limit: Maximum number of events to return. - - Returns: - List of event records ordered by timestamp ASC. - """ - where_clauses = ["session_id = %s"] - params: list[Any] = [session_id] + """List events for a session ordered by timestamp.""" + where_clauses = ["s.app_name = %s", "s.user_id = %s", "e.session_id = %s"] + params: list[Any] = [app_name, user_id, session_id] if after_timestamp is not None: - where_clauses.append("timestamp > %s") + where_clauses.append("e.timestamp > %s") params.append(after_timestamp) where_clause = " AND ".join(where_clauses) limit_clause = " LIMIT %s" if limit else "" sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_data - FROM {self._events_table} + SELECT e.id, e.session_id, e.invocation_id, e.timestamp, e.event_data, s.app_name, s.user_id + FROM {self._events_table} e + JOIN {self._session_table} s ON e.session_id = s.id WHERE {where_clause} - ORDER BY timestamp ASC{limit_clause} + ORDER BY e.timestamp ASC{limit_clause} """ if limit: params.append(limit) @@ -568,11 +555,13 @@ def _get_events( return [ EventRecord( - session_id=row[0], - invocation_id=row[1], - author=row[2], + id=row[0], + session_id=row[1], + invocation_id=row[2], timestamp=row[3], event_data=from_json(row[4]) if isinstance(row[4], str) else row[4], + app_name=row[5], + user_id=row[6], ) for row in rows ] diff --git a/sqlspec/adapters/spanner/adk/store.py b/sqlspec/adapters/spanner/adk/store.py index c74c0538d..2bb659b79 100644 --- a/sqlspec/adapters/spanner/adk/store.py +++ b/sqlspec/adapters/spanner/adk/store.py @@ -48,50 +48,49 @@ async def create_session( return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) async def get_session( - self, session_id: str, *, renew_for: "int | timedelta | None" = None + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None ) -> "SessionRecord | None": """Get session by ID.""" - return await async_(self._get_session)(session_id, renew_for) + return await async_(self._get_session)(app_name, user_id, session_id, renew_for) - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + async def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: """Update session state.""" - await async_(self._update_session_state)(session_id, state) + await async_(self._update_session_state)(app_name, user_id, session_id, state) async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": """List sessions for an app.""" return await async_(self._list_sessions)(app_name, user_id) - async def delete_session(self, session_id: str) -> None: + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: """Delete session and associated events.""" - await async_(self._delete_session)(session_id) + await async_(self._delete_session)(app_name, user_id, session_id) async def append_event_and_update_state( self, event_record: EventRecord, + app_name: str, + user_id: str, session_id: str, state: "dict[str, Any]", *, - app_name: "str | None" = None, - user_id: "str | None" = None, app_state: "dict[str, Any] | None" = None, user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: """Atomically append an event and update session + scoped state.""" return await async_(self._append_event_and_update_state)( - event_record, - session_id, - state, - app_name=app_name, - user_id=user_id, - app_state=app_state, - user_state=user_state, + event_record, app_name, user_id, session_id, state, app_state=app_state, user_state=user_state ) async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": """Get events for a session.""" - return await async_(self._get_events)(session_id, after_timestamp, limit) + return await async_(self._get_events)(app_name, user_id, session_id, after_timestamp, limit) async def delete_expired_events(self, before: "datetime") -> int: """Return 0 because Spanner row deletion policies own TTL cleanup.""" @@ -161,9 +160,9 @@ def _session_param_types(self, include_owner: bool) -> "dict[str, Any]": def _event_param_types(self) -> "dict[str, Any]": json_type = _json_param_type() return { + "id": SPANNER_PARAM_TYPES.STRING, "session_id": SPANNER_PARAM_TYPES.STRING, "invocation_id": SPANNER_PARAM_TYPES.STRING, - "author": SPANNER_PARAM_TYPES.STRING, "timestamp": SPANNER_PARAM_TYPES.TIMESTAMP, "event_data": json_type, } @@ -209,27 +208,44 @@ def _create_session( "update_time": datetime.now(timezone.utc), } - def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = None) -> "SessionRecord | None": + def _get_session( + self, app_name: str, user_id: str, session_id: str, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": if renew_for is not None and self._calculate_expires_at(renew_for) is not None: update_sql = f""" UPDATE {self._session_table} SET update_time = PENDING_COMMIT_TIMESTAMP() - WHERE id = @id + WHERE app_name = @app_name AND user_id = @user_id AND id = @id """ if self._shard_count > 1: update_sql = f"{update_sql} AND shard_id = MOD(FARM_FINGERPRINT(@id), {self._shard_count})" - self._run_write([(update_sql, {"id": session_id}, {"id": SPANNER_PARAM_TYPES.STRING})]) + self._run_write([ + ( + update_sql, + {"app_name": app_name, "user_id": user_id, "id": session_id}, + { + "app_name": SPANNER_PARAM_TYPES.STRING, + "user_id": SPANNER_PARAM_TYPES.STRING, + "id": SPANNER_PARAM_TYPES.STRING, + }, + ) + ]) sql = f""" SELECT id, app_name, user_id, state, create_time, update_time{", " + self._owner_id_column_name if self._owner_id_column_name else ""} FROM {self._session_table} - WHERE id = @id + WHERE app_name = @app_name AND user_id = @user_id AND id = @id """ if self._shard_count > 1: sql = f"{sql} AND shard_id = MOD(FARM_FINGERPRINT(@id), {self._shard_count})" sql = f"{sql} LIMIT 1" - params = {"id": session_id} - rows = self._run_read(sql, params, {"id": SPANNER_PARAM_TYPES.STRING}) + params = {"app_name": app_name, "user_id": user_id, "id": session_id} + types = { + "app_name": SPANNER_PARAM_TYPES.STRING, + "user_id": SPANNER_PARAM_TYPES.STRING, + "id": SPANNER_PARAM_TYPES.STRING, + } + rows = self._run_read(sql, params, types) if not rows: return None @@ -245,17 +261,23 @@ def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = No } return record - def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - params = {"id": session_id, "state": to_json(state)} + def _update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: + params = {"app_name": app_name, "user_id": user_id, "id": session_id, "state": to_json(state)} json_type = _json_param_type() sql = f""" UPDATE {self._session_table} SET state = @state, update_time = PENDING_COMMIT_TIMESTAMP() - WHERE id = @id + WHERE app_name = @app_name AND user_id = @user_id AND id = @id """ if self._shard_count > 1: sql = f"{sql} AND shard_id = MOD(FARM_FINGERPRINT(@id), {self._shard_count})" - self._run_write([(sql, params, {"id": SPANNER_PARAM_TYPES.STRING, "state": json_type})]) + types = { + "app_name": SPANNER_PARAM_TYPES.STRING, + "user_id": SPANNER_PARAM_TYPES.STRING, + "id": SPANNER_PARAM_TYPES.STRING, + "state": json_type, + } + self._run_write([(sql, params, types)]) def _list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[SessionRecord]": sql = f""" @@ -288,65 +310,77 @@ def _list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[S records.append(record) return records - def _delete_session(self, session_id: str) -> None: + def _delete_session(self, app_name: str, user_id: str, session_id: str) -> None: shard_clause = ( f" AND shard_id = MOD(FARM_FINGERPRINT(@session_id), {self._shard_count})" if self._shard_count > 1 else "" ) delete_events_sql = f"DELETE FROM {self._events_table} WHERE session_id = @session_id{shard_clause}" - delete_session_sql = f"DELETE FROM {self._session_table} WHERE id = @session_id{shard_clause}" - params = {"session_id": session_id} - types = {"session_id": SPANNER_PARAM_TYPES.STRING} - self._run_write([(delete_events_sql, params, types), (delete_session_sql, params, types)]) + delete_session_sql = f"DELETE FROM {self._session_table} WHERE app_name = @app_name AND user_id = @user_id AND id = @session_id{shard_clause}" + params = {"app_name": app_name, "user_id": user_id, "session_id": session_id} + types = { + "app_name": SPANNER_PARAM_TYPES.STRING, + "user_id": SPANNER_PARAM_TYPES.STRING, + "session_id": SPANNER_PARAM_TYPES.STRING, + } + self._run_write([ + (delete_events_sql, {"session_id": session_id}, {"session_id": SPANNER_PARAM_TYPES.STRING}), + (delete_session_sql, params, types), + ]) def _append_event_and_update_state( self, event_record: "EventRecord", + app_name: str, + user_id: str, session_id: str, state: "dict[str, Any]", *, - app_name: "str | None" = None, - user_id: "str | None" = None, app_state: "dict[str, Any] | None" = None, user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: - """Atomically insert event + update session + upsert scoped state. - - All writes execute within a single Spanner transaction so they succeed - or fail together. A follow-up single-use read returns the SessionRecord; - we can't capture update_time inside the write txn because - PENDING_COMMIT_TIMESTAMP() only materialises on commit. - """ + """Atomically insert event + update session + upsert scoped state.""" event_params: dict[str, Any] = { + "id": event_record["id"], "session_id": event_record["session_id"], "invocation_id": event_record["invocation_id"], - "author": event_record["author"], "timestamp": event_record["timestamp"], "event_data": to_json(event_record["event_data"]), } insert_sql = f""" - INSERT INTO {self._events_table} (session_id, invocation_id, author, timestamp, event_data) - VALUES (@session_id, @invocation_id, @author, @timestamp, @event_data) + INSERT INTO {self._events_table} (id, session_id, invocation_id, timestamp, event_data) + VALUES (@id, @session_id, @invocation_id, @timestamp, @event_data) """ json_type = _json_param_type() - state_params: dict[str, Any] = {"id": session_id, "state": to_json(state)} + state_params: dict[str, Any] = { + "app_name": app_name, + "user_id": user_id, + "id": session_id, + "state": to_json(state), + } update_sql = f""" UPDATE {self._session_table} SET state = @state, update_time = PENDING_COMMIT_TIMESTAMP() - WHERE id = @id + WHERE app_name = @app_name AND user_id = @user_id AND id = @id """ if self._shard_count > 1: update_sql = f"{update_sql} AND shard_id = MOD(FARM_FINGERPRINT(@id), {self._shard_count})" statements: list[tuple[str, dict[str, Any], dict[str, Any]]] = [ (insert_sql, event_params, self._event_param_types()), - (update_sql, state_params, {"id": SPANNER_PARAM_TYPES.STRING, "state": json_type}), + ( + update_sql, + state_params, + { + "app_name": SPANNER_PARAM_TYPES.STRING, + "user_id": SPANNER_PARAM_TYPES.STRING, + "id": SPANNER_PARAM_TYPES.STRING, + "state": json_type, + }, + ), ] if app_state: - if app_name is None: - msg = "app_name is required when app_state is provided." - raise ValueError(msg) app_delete_sql = f"DELETE FROM {self._app_state_table} WHERE app_name = @app_name" app_insert_sql = f""" INSERT INTO {self._app_state_table} (app_name, state, update_time) @@ -359,9 +393,6 @@ def _append_event_and_update_state( {"app_name": SPANNER_PARAM_TYPES.STRING, "state": json_type}, )) if user_state: - if app_name is None or user_id is None: - msg = "app_name and user_id are required when user_state is provided." - raise ValueError(msg) user_delete_sql = f"DELETE FROM {self._user_state_table} WHERE app_name = @app_name AND user_id = @user_id" user_insert_sql = f""" INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) @@ -380,7 +411,7 @@ def _append_event_and_update_state( self._run_write(statements) - record = self._get_session(session_id) + record = self._get_session(app_name, user_id, session_id) if record is None: msg = f"Session {session_id} not found during append_event_and_update_state." raise ValueError(msg) @@ -388,35 +419,45 @@ def _append_event_and_update_state( def _insert_event(self, event_record: "EventRecord") -> None: event_params: dict[str, Any] = { + "id": event_record["id"], "session_id": event_record["session_id"], "invocation_id": event_record["invocation_id"], - "author": event_record["author"], "timestamp": event_record["timestamp"], "event_data": to_json(event_record["event_data"]), } insert_sql = f""" - INSERT INTO {self._events_table} (session_id, invocation_id, author, timestamp, event_data) - VALUES (@session_id, @invocation_id, @author, @timestamp, @event_data) + INSERT INTO {self._events_table} (id, session_id, invocation_id, timestamp, event_data) + VALUES (@id, @session_id, @invocation_id, @timestamp, @event_data) """ self._run_write([(insert_sql, event_params, self._event_param_types())]) def _get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_data - FROM {self._events_table} - WHERE session_id = @session_id + SELECT e.id, e.session_id, e.invocation_id, e.timestamp, e.event_data, s.app_name, s.user_id + FROM {self._events_table} e + JOIN {self._session_table} s ON e.session_id = s.id + WHERE s.app_name = @app_name AND s.user_id = @user_id AND e.session_id = @session_id """ + params: dict[str, Any] = {"app_name": app_name, "user_id": user_id, "session_id": session_id} + types: dict[str, Any] = { + "app_name": SPANNER_PARAM_TYPES.STRING, + "user_id": SPANNER_PARAM_TYPES.STRING, + "session_id": SPANNER_PARAM_TYPES.STRING, + } if self._shard_count > 1: - sql = f"{sql} AND shard_id = MOD(FARM_FINGERPRINT(@session_id), {self._shard_count})" - params: dict[str, Any] = {"session_id": session_id} - types: dict[str, Any] = {"session_id": SPANNER_PARAM_TYPES.STRING} + sql = f"{sql} AND e.shard_id = MOD(FARM_FINGERPRINT(@session_id), {self._shard_count})" if after_timestamp is not None: - sql = f"{sql} AND timestamp > @after_timestamp" + sql = f"{sql} AND e.timestamp > @after_timestamp" params["after_timestamp"] = after_timestamp types["after_timestamp"] = SPANNER_PARAM_TYPES.TIMESTAMP - sql = f"{sql} ORDER BY timestamp ASC" + sql = f"{sql} ORDER BY e.timestamp ASC" if limit is not None: sql = f"{sql} LIMIT @limit" params["limit"] = limit @@ -424,11 +465,13 @@ def _get_events( rows = self._run_read(sql, params, types) return [ { - "session_id": row[0], - "invocation_id": row[1] or "", - "author": row[2] or "", + "id": row[0], + "session_id": row[1], + "invocation_id": row[2] or "", "timestamp": row[3], "event_data": row[4], + "app_name": row[5], + "user_id": row[6], } for row in rows ] @@ -555,20 +598,23 @@ async def _get_create_sessions_table_sql(self) -> str: async def _get_create_events_table_sql(self) -> str: shard_column = "" - pk = "PRIMARY KEY (session_id, timestamp)" + pk = "PRIMARY KEY (id)" + fk = f"CONSTRAINT fk_{self._events_table}_session FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE" if self._shard_count > 1: shard_column = f",\n shard_id INT64 AS (MOD(FARM_FINGERPRINT(session_id), {self._shard_count})) STORED" - pk = "PRIMARY KEY (shard_id, session_id, timestamp)" + pk = "PRIMARY KEY (shard_id, id)" + fk = f"CONSTRAINT fk_{self._events_table}_session FOREIGN KEY (shard_id, session_id) REFERENCES {self._session_table}(shard_id, id) ON DELETE CASCADE" options = "" if self._events_table_options: options = f"\nOPTIONS ({self._events_table_options})" return f""" CREATE TABLE {self._events_table} ( + id STRING(128) NOT NULL, session_id STRING(128) NOT NULL, - invocation_id STRING(256) NOT NULL, - author STRING(128) NOT NULL, + invocation_id STRING(256), timestamp TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp=true), - event_data JSON NOT NULL{shard_column} + event_data JSON NOT NULL{shard_column}, + {fk} ) {pk}{options}{self._events_row_deletion_policy} """ diff --git a/sqlspec/adapters/sqlite/adk/store.py b/sqlspec/adapters/sqlite/adk/store.py index 1ee8049b2..53579d7b7 100644 --- a/sqlspec/adapters/sqlite/adk/store.py +++ b/sqlspec/adapters/sqlite/adk/store.py @@ -146,11 +146,13 @@ async def create_session( return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) async def get_session( - self, session_id: str, *, renew_for: "int | timedelta | None" = None + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None ) -> "SessionRecord | None": - """Get session by ID. + """Get session. Args: + app_name: Name of the application. + user_id: ID of the user. session_id: Session identifier. renew_for: If positive, touch update_time while reading. @@ -161,12 +163,14 @@ async def get_session( SQLite returns Julian Day (REAL) for timestamps. JSON is parsed from TEXT storage. """ - return await async_(self._get_session)(session_id, renew_for) + return await async_(self._get_session)(app_name, user_id, session_id, renew_for) - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + async def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: """Update session state. Args: + app_name: Name of the application. + user_id: ID of the user. session_id: Session identifier. state: New state dictionary (replaces existing state). @@ -175,7 +179,7 @@ async def update_session_state(self, session_id: str, state: "dict[str, Any]") - Updates update_time to current Julian Day. Empty dict is serialized as '{}', never NULL. """ - await async_(self._update_session_state)(session_id, state) + await async_(self._update_session_state)(app_name, user_id, session_id, state) async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": """List sessions for an app, optionally filtered by user. @@ -192,23 +196,24 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis """ return await async_(self._list_sessions)(app_name, user_id) - async def delete_session(self, session_id: str) -> None: + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: """Delete session and all associated events (cascade). Args: + app_name: Name of the application. + user_id: ID of the user. session_id: Session identifier. Notes: Foreign key constraint ensures events are cascade-deleted. """ - await async_(self._delete_session)(session_id) + await async_(self._delete_session)(app_name, user_id, session_id) async def append_event(self, event_record: EventRecord) -> None: """Append an event to a session. Args: - event_record: Event record with 5 keys: session_id, invocation_id, - author, timestamp, event_data. + event_record: Event record. Notes: Uses Julian Day for timestamp. @@ -219,31 +224,32 @@ async def append_event(self, event_record: EventRecord) -> None: async def append_event_and_update_state( self, event_record: EventRecord, + app_name: str, + user_id: str, session_id: str, state: "dict[str, Any]", *, - app_name: "str | None" = None, - user_id: "str | None" = None, app_state: "dict[str, Any] | None" = None, user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: """Atomically append an event and update session + scoped state.""" return await async_(self._append_event_and_update_state)( - event_record, - session_id, - state, - app_name=app_name, - user_id=user_id, - app_state=app_state, - user_state=user_state, + event_record, app_name, user_id, session_id, state, app_state=app_state, user_state=user_state ) async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": """Get events for a session. Args: + app_name: Name of the application. + user_id: ID of the user. session_id: Session identifier. after_timestamp: Only return events after this time. limit: Maximum number of events to return. @@ -255,7 +261,7 @@ async def get_events( Uses index on (session_id, timestamp ASC). Parses event_data TEXT back to dict for event_data field. """ - return await async_(self._get_events)(session_id, after_timestamp, limit) + return await async_(self._get_events)(app_name, user_id, session_id, after_timestamp, limit) async def delete_expired_events(self, before: datetime) -> int: """Delete events older than the given timestamp.""" @@ -355,7 +361,6 @@ async def _get_create_events_table_sql(self) -> str: id TEXT PRIMARY KEY, session_id TEXT NOT NULL, invocation_id TEXT, - author TEXT, timestamp REAL NOT NULL, event_data TEXT NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE @@ -475,23 +480,27 @@ def _create_session( id=session_id, app_name=app_name, user_id=user_id, state=state, create_time=now, update_time=now ) - def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = None) -> "SessionRecord | None": + def _get_session( + self, app_name: str, user_id: str, session_id: str, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": """Synchronous implementation of get_session.""" sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} - WHERE id = ? + WHERE app_name = ? AND user_id = ? AND id = ? """ try: with self._config.provide_connection() as conn: self._apply_pragmas(conn) if renew_for is not None and self._calculate_expires_at(renew_for) is not None: - update_sql = f"UPDATE {self._session_table} SET update_time = ? WHERE id = ?" - conn.execute(update_sql, (_datetime_to_julian(datetime.now(timezone.utc)), session_id)) + update_sql = f"UPDATE {self._session_table} SET update_time = ? WHERE app_name = ? AND user_id = ? AND id = ?" + conn.execute( + update_sql, (_datetime_to_julian(datetime.now(timezone.utc)), app_name, user_id, session_id) + ) conn.commit() - cursor = conn.execute(sql, (session_id,)) + cursor = conn.execute(sql, (app_name, user_id, session_id)) row = cursor.fetchone() if row is None: @@ -510,7 +519,7 @@ def _get_session(self, session_id: str, renew_for: "int | timedelta | None" = No return None raise - def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + def _update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: """Synchronous implementation of update_session_state.""" now_julian = _datetime_to_julian(datetime.now(timezone.utc)) state_json = to_json(state) @@ -518,12 +527,12 @@ def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> Non sql = f""" UPDATE {self._session_table} SET state = ?, update_time = ? - WHERE id = ? + WHERE app_name = ? AND user_id = ? AND id = ? """ with self._config.provide_connection() as conn: self._apply_pragmas(conn) - conn.execute(sql, (state_json, now_julian, session_id)) + conn.execute(sql, (state_json, now_julian, app_name, user_id, session_id)) conn.commit() def _list_sessions(self, app_name: str, user_id: "str | None") -> "list[SessionRecord]": @@ -567,13 +576,13 @@ def _list_sessions(self, app_name: str, user_id: "str | None") -> "list[SessionR return [] raise - def _delete_session(self, session_id: str) -> None: + def _delete_session(self, app_name: str, user_id: str, session_id: str) -> None: """Synchronous implementation of delete_session.""" - sql = f"DELETE FROM {self._session_table} WHERE id = ?" + sql = f"DELETE FROM {self._session_table} WHERE app_name = ? AND user_id = ? AND id = ?" with self._config.provide_connection() as conn: self._apply_pragmas(conn) - conn.execute(sql, (session_id,)) + conn.execute(sql, (app_name, user_id, session_id)) conn.commit() def _append_event(self, event_record: EventRecord) -> None: @@ -583,23 +592,18 @@ def _append_event(self, event_record: EventRecord) -> None: sql = f""" INSERT INTO {self._events_table} ( - id, session_id, invocation_id, author, timestamp, event_data - ) VALUES (?, ?, ?, ?, ?, ?) + id, session_id, invocation_id, timestamp, event_data + ) VALUES (?, ?, ?, ?, ?) """ - import uuid - - event_id = str(uuid.uuid4()) - with self._config.provide_connection() as conn: self._apply_pragmas(conn) conn.execute( sql, ( - event_id, + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], timestamp_julian, event_data_json, ), @@ -609,33 +613,30 @@ def _append_event(self, event_record: EventRecord) -> None: def _append_event_and_update_state( self, event_record: EventRecord, + app_name: str, + user_id: str, session_id: str, state: "dict[str, Any]", *, - app_name: "str | None" = None, - user_id: "str | None" = None, app_state: "dict[str, Any] | None" = None, user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: """Synchronous implementation of append_event_and_update_state.""" - import uuid - timestamp_julian = _datetime_to_julian(event_record["timestamp"]) event_data_json = to_json(event_record["event_data"]) now_julian = _datetime_to_julian(datetime.now(timezone.utc)) state_json = to_json(state) - event_id = str(uuid.uuid4()) insert_sql = f""" INSERT INTO {self._events_table} ( - id, session_id, invocation_id, author, timestamp, event_data - ) VALUES (?, ?, ?, ?, ?, ?) + id, session_id, invocation_id, timestamp, event_data + ) VALUES (?, ?, ?, ?, ?) """ update_sql = f""" UPDATE {self._session_table} SET state = ?, update_time = ? - WHERE id = ? + WHERE app_name = ? AND user_id = ? AND id = ? RETURNING id, app_name, user_id, state, create_time, update_time """ @@ -660,25 +661,18 @@ def _append_event_and_update_state( conn.execute( insert_sql, ( - event_id, + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], timestamp_julian, event_data_json, ), ) - cursor = conn.execute(update_sql, (state_json, now_julian, session_id)) + cursor = conn.execute(update_sql, (state_json, now_julian, app_name, user_id, session_id)) row = cursor.fetchone() if app_state: - if app_name is None: - msg = "app_name is required when app_state is provided." - raise ValueError(msg) conn.execute(app_upsert_sql, (app_name, to_json(app_state), now_julian)) if user_state: - if app_name is None or user_id is None: - msg = "app_name and user_id are required when user_state is provided." - raise ValueError(msg) conn.execute(user_upsert_sql, (app_name, user_id, to_json(user_state), now_julian)) conn.commit() @@ -696,24 +690,30 @@ def _append_event_and_update_state( ) def _get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": """Synchronous implementation of get_events.""" - where_clauses = ["session_id = ?"] - params: list[Any] = [session_id] + where_clauses = ["s.app_name = ?", "s.user_id = ?", "e.session_id = ?"] + params: list[Any] = [app_name, user_id, session_id] if after_timestamp is not None: - where_clauses.append("timestamp > ?") + where_clauses.append("e.timestamp > ?") params.append(_datetime_to_julian(after_timestamp)) where_clause = " AND ".join(where_clauses) limit_clause = f" LIMIT {limit}" if limit else "" sql = f""" - SELECT id, session_id, invocation_id, author, timestamp, event_data - FROM {self._events_table} + SELECT e.id, e.session_id, e.invocation_id, e.timestamp, e.event_data, s.app_name, s.user_id + FROM {self._events_table} e + JOIN {self._session_table} s ON e.session_id = s.id WHERE {where_clause} - ORDER BY timestamp ASC{limit_clause} + ORDER BY e.timestamp ASC{limit_clause} """ try: @@ -724,11 +724,13 @@ def _get_events( return [ EventRecord( + id=row[0], session_id=row[1], invocation_id=row[2], - author=row[3], - timestamp=_julian_to_datetime(row[4]), - event_data=from_json(row[5]) if row[5] else {}, + timestamp=_julian_to_datetime(row[3]), + event_data=from_json(row[4]) if row[4] else {}, + app_name=row[5], + user_id=row[6], ) for row in rows ] diff --git a/sqlspec/extensions/adk/_types.py b/sqlspec/extensions/adk/_types.py index 2864c437b..b53847ba8 100644 --- a/sqlspec/extensions/adk/_types.py +++ b/sqlspec/extensions/adk/_types.py @@ -34,8 +34,10 @@ class EventRecord(TypedDict): automatically captured in ``event_data`` without schema changes. """ + id: str + app_name: str + user_id: str session_id: str invocation_id: str - author: str timestamp: datetime event_data: "dict[str, Any]" diff --git a/sqlspec/extensions/adk/converters.py b/sqlspec/extensions/adk/converters.py index 6ebd82e06..520208c4f 100644 --- a/sqlspec/extensions/adk/converters.py +++ b/sqlspec/extensions/adk/converters.py @@ -103,7 +103,7 @@ def record_to_session(record: SessionRecord, events: "list[EventRecord]") -> "Se # --------------------------------------------------------------------------- -def event_to_record(event: "Event", session_id: str) -> EventRecord: +def event_to_record(event: "Event", app_name: str, user_id: str, session_id: str) -> EventRecord: """Convert ADK Event to database record using full-event JSON storage. The entire Event is serialized into ``event_data`` via Pydantic's @@ -112,15 +112,19 @@ def event_to_record(event: "Event", session_id: str) -> EventRecord: Args: event: ADK Event object. + app_name: Name of the application. + user_id: ID of the user. session_id: ID of the parent session. Returns: EventRecord for database storage. """ return EventRecord( + id=event.id, + app_name=app_name, + user_id=user_id, session_id=session_id, invocation_id=event.invocation_id, - author=event.author, timestamp=datetime.fromtimestamp(event.timestamp, tz=timezone.utc), event_data=event.model_dump(exclude_none=True, mode="json"), ) diff --git a/sqlspec/extensions/adk/service.py b/sqlspec/extensions/adk/service.py index 26c4aaaa0..75891f1a8 100644 --- a/sqlspec/extensions/adk/service.py +++ b/sqlspec/extensions/adk/service.py @@ -125,7 +125,7 @@ async def get_session( Returns: Session object if found, None otherwise. """ - record = await self._store.get_session(session_id, renew_for=renew_for) + record = await self._store.get_session(app_name, user_id, session_id, renew_for=renew_for) if not record: log_with_context( @@ -151,7 +151,9 @@ async def get_session( after_timestamp = datetime.fromtimestamp(config.after_timestamp, tz=timezone.utc) limit = config.num_recent_events - events = await self._store.get_events(session_id=session_id, after_timestamp=after_timestamp, limit=limit) + events = await self._store.get_events( + app_name=app_name, user_id=user_id, session_id=session_id, after_timestamp=after_timestamp, limit=limit + ) log_with_context( logger, logging.DEBUG, @@ -196,7 +198,7 @@ async def delete_session(self, *, app_name: str, user_id: str, session_id: str) user_id: ID of the user. session_id: Session identifier. """ - record = await self._store.get_session(session_id) + record = await self._store.get_session(app_name, user_id, session_id) if not record: log_with_context( @@ -210,7 +212,7 @@ async def delete_session(self, *, app_name: str, user_id: str, session_id: str) ) return - await self._store.delete_session(session_id) + await self._store.delete_session(app_name, user_id, session_id) log_with_context( logger, logging.DEBUG, "adk.session.delete", app_name=app_name, session_id=session_id, deleted=True ) @@ -249,7 +251,9 @@ async def append_event(self, session: "Session", event: "Event") -> "Event": self._apply_temp_state(session, event) event = self._trim_temp_delta_state(event) - event_record = event_to_record(event=event, session_id=session.id) + event_record = event_to_record( + event=event, app_name=session.app_name, user_id=session.user_id, session_id=session.id + ) # Build durable state: current state minus temp keys, plus the # event's state delta (temp keys already stripped by _trim above). @@ -259,7 +263,7 @@ async def append_event(self, session: "Session", event: "Event") -> "Event": app_state, user_state, session_state = split_scoped_state(durable_state) # --- Stale-session detection --- - current_record = await self._store.get_session(session.id) + current_record = await self._store.get_session(session.app_name, session.user_id, session.id) if current_record is None: msg = f"Session {session.id} not found." raise ValueError(msg) @@ -282,10 +286,10 @@ async def append_event(self, session: "Session", event: "Event") -> "Event": # --- Persist event and all scoped state atomically --- updated_record = await self._store.append_event_and_update_state( event_record=event_record, - session_id=session.id, - state=session_state, app_name=session.app_name, user_id=session.user_id, + session_id=session.id, + state=session_state, app_state=app_state or None, user_state=user_state or None, ) diff --git a/sqlspec/extensions/adk/store.py b/sqlspec/extensions/adk/store.py index ff1fe1d8b..00c0e7057 100644 --- a/sqlspec/extensions/adk/store.py +++ b/sqlspec/extensions/adk/store.py @@ -226,11 +226,13 @@ async def create_session( @abstractmethod async def get_session( - self, session_id: str, *, renew_for: "int | timedelta | None" = None + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None ) -> "SessionRecord | None": - """Get a session by ID. + """Get a session. Args: + app_name: Name of the application. + user_id: ID of the user. session_id: Session identifier. renew_for: If positive, touch the session update timestamp while reading. @@ -240,10 +242,12 @@ async def get_session( raise NotImplementedError @abstractmethod - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + async def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: """Update session state. Args: + app_name: Name of the application. + user_id: ID of the user. session_id: Session identifier. state: New state dictionary. """ @@ -263,10 +267,12 @@ async def list_sessions(self, app_name: str, user_id: "str | None" = None) -> "l raise NotImplementedError @abstractmethod - async def delete_session(self, session_id: str) -> None: + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: """Delete a session and its events. Args: + app_name: Name of the application. + user_id: ID of the user. session_id: Session identifier. """ raise NotImplementedError @@ -284,11 +290,11 @@ async def append_event(self, event_record: "EventRecord") -> None: async def append_event_and_update_state( self, event_record: "EventRecord", + app_name: str, + user_id: str, session_id: str, state: "dict[str, Any]", *, - app_name: "str | None" = None, - user_id: "str | None" = None, app_state: "dict[str, Any] | None" = None, user_state: "dict[str, Any] | None" = None, ) -> "SessionRecord": @@ -307,13 +313,11 @@ async def append_event_and_update_state( Args: event_record: Event record to store. + app_name: Application name for routing scoped-state upserts. + user_id: User identifier for routing user-scoped upserts. session_id: Session identifier whose state should be updated. state: Post-append durable session-scoped state snapshot (``temp:`` keys already stripped by the service layer). - app_name: Application name for routing scoped-state upserts. Required - when ``app_state`` or ``user_state`` is non-empty. - user_id: User identifier for routing user-scoped upserts. Required - when ``user_state`` is non-empty. app_state: App-scoped state delta (``app:*`` keys) to upsert atomically. user_state: User-scoped state delta (``user:*`` keys) to upsert atomically. @@ -328,11 +332,18 @@ async def append_event_and_update_state( @abstractmethod async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": """Get events for a session. Args: + app_name: Name of the application. + user_id: ID of the user. session_id: Session identifier. after_timestamp: Only return events after this time. limit: Maximum number of events to return. diff --git a/tests/integration/adapters/_adk_contract_helpers.py b/tests/integration/adapters/_adk_contract_helpers.py index 6a720e759..e064fb260 100644 --- a/tests/integration/adapters/_adk_contract_helpers.py +++ b/tests/integration/adapters/_adk_contract_helpers.py @@ -31,31 +31,38 @@ async def create_session( ) -> SessionRecord: ... async def get_session( - self, session_id: str, *, renew_for: int | timedelta | None = None + self, app_name: str, user_id: str, session_id: str, *, renew_for: int | timedelta | None = None ) -> SessionRecord | None: ... - async def update_session_state(self, session_id: str, state: dict[str, object]) -> None: ... + async def update_session_state( + self, app_name: str, user_id: str, session_id: str, state: dict[str, object] + ) -> None: ... async def list_sessions(self, app_name: str, user_id: str | None = None) -> list[SessionRecord]: ... - async def delete_session(self, session_id: str) -> None: ... + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: ... async def append_event(self, event_record: EventRecord) -> None: ... async def append_event_and_update_state( self, event_record: EventRecord, + app_name: str, + user_id: str, session_id: str, state: dict[str, object], *, - app_name: str | None = None, - user_id: str | None = None, app_state: dict[str, object] | None = None, user_state: dict[str, object] | None = None, ) -> SessionRecord: ... async def get_events( - self, session_id: str, after_timestamp: datetime | None = None, limit: int | None = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: datetime | None = None, + limit: int | None = None, ) -> list[EventRecord]: ... async def delete_expired_events(self, before: datetime) -> int: ... @@ -95,19 +102,22 @@ def _contract_key(marker: str, suffix: str) -> str: def _event_record( *, - session_id: str, event_id: str, + app_name: str, + user_id: str, + session_id: str, invocation_id: str, - author: str, timestamp: datetime, event_data: dict[str, object], ) -> EventRecord: data = dict(event_data) data.setdefault("id", event_id) return { + "id": event_id, + "app_name": app_name, + "user_id": user_id, "session_id": session_id, "invocation_id": invocation_id, - "author": author, "timestamp": timestamp, "event_data": data, } @@ -172,10 +182,11 @@ async def assert_session_event_store_contract(store: SessionEventStore, *, marke assert created["state"] == {"created": True} first_event = _event_record( - session_id=session_id, event_id="contract-event-1", + app_name=app_name, + user_id=user_id, + session_id=session_id, invocation_id="contract-inv-1", - author="user", timestamp=base_time, event_data={ "output": {"kind": "text", "value": "captured by full-event JSON"}, @@ -188,15 +199,15 @@ async def assert_session_event_store_contract(store: SessionEventStore, *, marke }, }, ) - updated = await store.append_event_and_update_state(first_event, session_id, {"turn": 1}) + updated = await store.append_event_and_update_state(first_event, app_name, user_id, session_id, {"turn": 1}) assert updated["id"] == session_id assert updated["state"] == {"turn": 1} - fetched = await store.get_session(session_id) + fetched = await store.get_session(app_name, user_id, session_id) assert fetched is not None assert fetched["state"] == {"turn": 1} - stored_events = await store.get_events(session_id) + stored_events = await store.get_events(app_name, user_id, session_id) assert len(stored_events) == 1 assert stored_events[0]["invocation_id"] == "contract-inv-1" first_data = _event_data(stored_events[0]) @@ -211,34 +222,38 @@ async def assert_session_event_store_contract(store: SessionEventStore, *, marke await store.append_event( _event_record( - session_id=session_id, event_id="contract-event-2", + app_name=app_name, + user_id=user_id, + session_id=session_id, invocation_id="contract-inv-2", - author="model", timestamp=base_time + timedelta(seconds=1), event_data={"content": {"parts": [{"text": "second"}]}}, ) ) await store.append_event( _event_record( - session_id=session_id, event_id="contract-event-3", + app_name=app_name, + user_id=user_id, + session_id=session_id, invocation_id="contract-inv-3", - author="model", timestamp=base_time + timedelta(seconds=2), event_data={"content": {"parts": [{"text": "third"}]}}, ) ) - filtered = await store.get_events(session_id, after_timestamp=base_time + timedelta(milliseconds=500), limit=1) + filtered = await store.get_events( + app_name, user_id, session_id, after_timestamp=base_time + timedelta(milliseconds=500), limit=1 + ) assert [event["invocation_id"] for event in filtered] == ["contract-inv-2"] listed = await store.list_sessions(app_name, user_id) assert any(record["id"] == session_id for record in listed) - await store.delete_session(session_id) - assert await store.get_session(session_id) is None - assert await store.get_events(session_id) == [] + await store.delete_session(app_name, user_id, session_id) + assert await store.get_session(app_name, user_id, session_id) is None + assert await store.get_events(app_name, user_id, session_id) == [] async def assert_session_get_session_renewal_contract(store: SessionEventStore, *, marker: str) -> None: @@ -252,7 +267,7 @@ async def assert_session_get_session_renewal_contract(store: SessionEventStore, await asyncio.sleep(0.02) before_renewal = datetime.now(timezone.utc) - timedelta(seconds=2) - renewed = await store.get_session(session_id, renew_for=timedelta(hours=1)) + renewed = await store.get_session(app_name, user_id, session_id, renew_for=timedelta(hours=1)) after_renewal = datetime.now(timezone.utc) + timedelta(seconds=2) assert renewed is not None @@ -298,7 +313,7 @@ async def assert_session_scoped_state_contract(store: SessionEventStore, *, mark ) await service.append_event(session_a, event) - raw_session = await store.get_session(session_a.id) + raw_session = await store.get_session(app_name, user_id, session_a.id) assert raw_session is not None assert raw_session["state"] == {"session_seed": "a", "turn": 1} assert await store.get_app_state(app_name) == {"app:counter": 1} @@ -335,20 +350,21 @@ async def assert_session_atomic_scoped_write_contract(store: SessionEventStore, await store.create_session(session_id, app_name, user_id, {"initial": 0}) event = _event_record( - session_id=session_id, event_id="atomic-event-1", + app_name=app_name, + user_id=user_id, + session_id=session_id, invocation_id="atomic-inv-1", - author="user", timestamp=base_time, event_data={"actions": {"state_delta": {"app:counter": 1, "user:theme": "dark", "turn": 1}}}, ) updated = await store.append_event_and_update_state( event, + app_name, + user_id, session_id, {"turn": 1}, - app_name=app_name, - user_id=user_id, app_state={"app:counter": 1}, user_state={"user:theme": "dark"}, ) @@ -357,20 +373,21 @@ async def assert_session_atomic_scoped_write_contract(store: SessionEventStore, assert updated["id"] == session_id assert await store.get_app_state(app_name) == {"app:counter": 1} assert await store.get_user_state(app_name, user_id) == {"user:theme": "dark"} - stored_events = await store.get_events(session_id) + stored_events = await store.get_events(app_name, user_id, session_id) assert any(record["invocation_id"] == "atomic-inv-1" for record in stored_events) await store.create_session(no_scope_session_id, app_name, user_id, {"phase": 0}) no_scope_event = _event_record( - session_id=no_scope_session_id, event_id="atomic-event-2", + app_name=app_name, + user_id=user_id, + session_id=no_scope_session_id, invocation_id="atomic-inv-2", - author="model", timestamp=base_time + timedelta(seconds=1), event_data={"content": {"parts": [{"text": "no scope delta"}]}}, ) no_scope_update = await store.append_event_and_update_state( - no_scope_event, no_scope_session_id, {"phase": 1}, app_name=app_name, user_id=user_id + no_scope_event, app_name, user_id, no_scope_session_id, {"phase": 1} ) assert no_scope_update["state"] == {"phase": 1} # Skipped scoped writes leave existing app/user state untouched. @@ -401,7 +418,7 @@ async def assert_session_temp_state_not_persisted(store: SessionEventStore, *, m ) await service.append_event(session, event) - raw_session = await store.get_session(session_id) + raw_session = await store.get_session(app_name, user_id, session_id) assert raw_session is not None assert "temp:scratch" not in raw_session["state"] assert "temp:create_seed" not in raw_session["state"] @@ -428,22 +445,23 @@ async def assert_session_empty_state_roundtrip(store: SessionEventStore, *, mark created = await store.create_session(session_id, app_name, user_id, {}) assert created["state"] == {} - fetched = await store.get_session(session_id) + fetched = await store.get_session(app_name, user_id, session_id) assert fetched is not None assert fetched["state"] == {} event = _event_record( - session_id=session_id, event_id="empty-event-1", + app_name=app_name, + user_id=user_id, + session_id=session_id, invocation_id="empty-inv-1", - author="user", timestamp=base_time, event_data={"content": {"parts": [{"text": "no state delta"}]}}, ) - updated = await store.append_event_and_update_state(event, session_id, {}, app_name=app_name, user_id=user_id) + updated = await store.append_event_and_update_state(event, app_name, user_id, session_id, {}) assert updated["state"] == {} - after = await store.get_session(session_id) + after = await store.get_session(app_name, user_id, session_id) assert after is not None assert after["state"] == {} @@ -464,15 +482,16 @@ async def assert_session_sibling_app_isolation(store: SessionEventStore, *, mark await store.create_session(session_b, app_b, user_id, {}) event = _event_record( - session_id=session_a, event_id="sibling-app-event-1", + app_name=app_a, + user_id=user_id, + session_id=session_a, invocation_id="sibling-app-inv-1", - author="user", timestamp=base_time, event_data={"actions": {"state_delta": {"app:counter": 7, "turn": 1}}}, ) await store.append_event_and_update_state( - event, session_a, {"turn": 1}, app_name=app_a, user_id=user_id, app_state={"app:counter": 7} + event, app_a, user_id, session_a, {"turn": 1}, app_state={"app:counter": 7} ) assert await store.get_app_state(app_a) == {"app:counter": 7} @@ -492,15 +511,16 @@ async def assert_session_sibling_user_isolation(store: SessionEventStore, *, mar await store.create_session(session_b, app_name, user_b, {}) event = _event_record( - session_id=session_a, event_id="sibling-user-event-1", + app_name=app_name, + user_id=user_a, + session_id=session_a, invocation_id="sibling-user-inv-1", - author="user", timestamp=base_time, event_data={"actions": {"state_delta": {"user:pref": "dark", "turn": 1}}}, ) await store.append_event_and_update_state( - event, session_a, {"turn": 1}, app_name=app_name, user_id=user_a, user_state={"user:pref": "dark"} + event, app_name, user_a, session_a, {"turn": 1}, user_state={"user:pref": "dark"} ) assert await store.get_user_state(app_name, user_a) == {"user:pref": "dark"} @@ -514,16 +534,16 @@ async def assert_session_table_lifecycle_contract(store: SessionEventStore, *, m session_id = _contract_key(marker, "lifecycle-session") await store.create_session(session_id, app_name, user_id, {"phase": "before"}) - assert await store.get_session(session_id) is not None + assert await store.get_session(app_name, user_id, session_id) is not None await store.recreate_tables() - assert await store.get_session(session_id) is None + assert await store.get_session(app_name, user_id, session_id) is None recreated = await store.create_session(session_id, app_name, user_id, {"phase": "after"}) assert recreated["state"] == {"phase": "after"} await store.drop_tables() - assert await store.get_session(session_id) is None + assert await store.get_session(app_name, user_id, session_id) is None await store.drop_tables() await store.recreate_tables() @@ -539,20 +559,22 @@ async def assert_session_event_cleanup_contract(store: SessionEventStore, *, mar await store.create_session(session_id, app_name, user_id, {"cleanup": True}) await store.append_event( _event_record( - session_id=session_id, event_id="cleanup-old-event", + app_name=app_name, + user_id=user_id, + session_id=session_id, invocation_id="cleanup-old", - author="user", timestamp=old_time, event_data={"content": {"parts": [{"text": "old"}]}}, ) ) await store.append_event( _event_record( - session_id=session_id, event_id="cleanup-new-event", + app_name=app_name, + user_id=user_id, + session_id=session_id, invocation_id="cleanup-new", - author="user", timestamp=new_time, event_data={"content": {"parts": [{"text": "new"}]}}, ) @@ -560,13 +582,13 @@ async def assert_session_event_cleanup_contract(store: SessionEventStore, *, mar deleted_events = await store.delete_expired_events(datetime(2026, 5, 5, tzinfo=timezone.utc)) assert deleted_events == 1 - remaining_events = await store.get_events(session_id) + remaining_events = await store.get_events(app_name, user_id, session_id) assert [event["invocation_id"] for event in remaining_events] == ["cleanup-new"] deleted_sessions = await store.delete_idle_sessions(datetime(2100, 1, 1, tzinfo=timezone.utc)) assert deleted_sessions == 1 - assert await store.get_session(session_id) is None - assert await store.get_events(session_id) == [] + assert await store.get_session(app_name, user_id, session_id) is None + assert await store.get_events(app_name, user_id, session_id) == [] async def assert_memory_store_contract(store: MemoryStore, *, marker: str) -> None: diff --git a/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py b/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py index 02c4002ed..2e978a808 100644 --- a/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py +++ b/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py @@ -57,7 +57,7 @@ async def test_aiosqlite_session_empty_state_round_trip(tmp_path: Path) -> None: config, store = await _build_store(tmp_path) try: created = await store.create_session("session-empty", "app", "user", {}) - fetched = await store.get_session("session-empty") + fetched = await store.get_session("app", "user", "session-empty") assert created["state"] == {} assert fetched is not None @@ -164,16 +164,18 @@ async def test_aiosqlite_append_event_and_update_state_is_atomic_contract(tmp_pa await store.create_session(session_id, "app", "user", {}) event: EventRecord = { + "id": "event-1", + "app_name": "app", + "user_id": "user", "session_id": session_id, "invocation_id": "inv-1", - "author": "user", "timestamp": datetime(2026, 5, 10, 12, 0, tzinfo=timezone.utc), "event_data": {"id": "event-1", "content": {"parts": [{"text": "hello"}]}}, } - await store.append_event_and_update_state(event, session_id, {"turn": 1}) + await store.append_event_and_update_state(event, "app", "user", session_id, {"turn": 1}) - session = await store.get_session(session_id) - events = await store.get_events(session_id) + session = await store.get_session("app", "user", session_id) + events = await store.get_events("app", "user", session_id) assert session is not None assert session["state"] == {"turn": 1} @@ -190,10 +192,10 @@ async def test_aiosqlite_reads_return_empty_when_tables_missing(tmp_path: Path) config = AiosqliteConfig(connection_config={"database": str(db_path)}) store = AiosqliteADKStore(config) try: - assert await store.get_session("missing") is None + assert await store.get_session("app", "user", "missing") is None assert await store.list_sessions("app") == [] assert await store.list_sessions("app", "user") == [] - assert await store.get_events("session-x") == [] + assert await store.get_events("app", "user", "session-x") == [] finally: await config.close_pool() @@ -208,15 +210,19 @@ async def test_aiosqlite_get_events_filters_by_timestamp_and_limit(tmp_path: Pat for index in range(3): event: EventRecord = { + "id": f"event-{index}", + "app_name": "app", + "user_id": "user", "session_id": session_id, "invocation_id": f"inv-{index}", - "author": "user", "timestamp": base + timedelta(seconds=index), "event_data": {"id": f"event-{index}"}, } await store.append_event(event) - events = await store.get_events(session_id, after_timestamp=base + timedelta(milliseconds=500), limit=1) + events = await store.get_events( + "app", "user", session_id, after_timestamp=base + timedelta(milliseconds=500), limit=1 + ) assert len(events) == 1 assert events[0]["invocation_id"] == "inv-1" diff --git a/tests/integration/adapters/duckdb/extensions/adk/test_store.py b/tests/integration/adapters/duckdb/extensions/adk/test_store.py index ba53ed16b..eb9803e0c 100644 --- a/tests/integration/adapters/duckdb/extensions/adk/test_store.py +++ b/tests/integration/adapters/duckdb/extensions/adk/test_store.py @@ -127,7 +127,7 @@ async def test_create_and_get_session(duckdb_adk_store: DuckdbADKStore) -> None: assert isinstance(created_session["create_time"], datetime) assert isinstance(created_session["update_time"], datetime) - retrieved_session = await duckdb_adk_store.get_session(session_id) + retrieved_session = await duckdb_adk_store.get_session(app_name, user_id, session_id) assert retrieved_session is not None assert retrieved_session["id"] == session_id assert retrieved_session["state"] == state @@ -135,7 +135,7 @@ async def test_create_and_get_session(duckdb_adk_store: DuckdbADKStore) -> None: async def test_get_nonexistent_session(duckdb_adk_store: DuckdbADKStore) -> None: """Test getting a non-existent session returns None.""" - result = await duckdb_adk_store.get_session("nonexistent-session") + result = await duckdb_adk_store.get_session("test-app", "user-001", "nonexistent-session") assert result is None @@ -149,13 +149,13 @@ async def test_update_session_state(duckdb_adk_store: DuckdbADKStore) -> None: session_id=session_id, app_name="test-app", user_id="user-002", state=initial_state ) - session_before = await duckdb_adk_store.get_session(session_id) + session_before = await duckdb_adk_store.get_session("test-app", "user-002", session_id) assert session_before is not None assert session_before["state"] == initial_state - await duckdb_adk_store.update_session_state(session_id, updated_state) + await duckdb_adk_store.update_session_state("test-app", "user-002", session_id, updated_state) - session_after = await duckdb_adk_store.get_session(session_id) + session_after = await duckdb_adk_store.get_session("test-app", "user-002", session_id) assert session_after is not None assert session_after["state"] == updated_state assert session_after["update_time"] >= session_before["update_time"] @@ -191,11 +191,11 @@ async def test_delete_session(duckdb_adk_store: DuckdbADKStore) -> None: session_id = "session-to-delete" await duckdb_adk_store.create_session(session_id, "test-app", "user-004", {"data": "test"}) - assert await duckdb_adk_store.get_session(session_id) is not None + assert await duckdb_adk_store.get_session("test-app", "user-004", session_id) is not None - await duckdb_adk_store.delete_session(session_id) + await duckdb_adk_store.delete_session("test-app", "user-004", session_id) - assert await duckdb_adk_store.get_session(session_id) is None + assert await duckdb_adk_store.get_session("test-app", "user-004", session_id) is None async def test_delete_session_cascade_events(duckdb_adk_store: DuckdbADKStore) -> None: @@ -204,9 +204,11 @@ async def test_delete_session_cascade_events(duckdb_adk_store: DuckdbADKStore) - await duckdb_adk_store.create_session(session_id, "test-app", "user-005", {"data": "test"}) event_record: EventRecord = { + "id": "event-001", + "app_name": "test-app", + "user_id": "user-005", "session_id": session_id, "invocation_id": "", - "author": "user", "timestamp": datetime.now(timezone.utc), "event_data": { "id": "event-001", @@ -217,13 +219,13 @@ async def test_delete_session_cascade_events(duckdb_adk_store: DuckdbADKStore) - } await duckdb_adk_store.append_event(event_record) - events = await duckdb_adk_store.get_events(session_id) + events = await duckdb_adk_store.get_events("test-app", "user-005", session_id) assert len(events) == 1 - await duckdb_adk_store.delete_session(session_id) + await duckdb_adk_store.delete_session("test-app", "user-005", session_id) - assert await duckdb_adk_store.get_session(session_id) is None - events_after = await duckdb_adk_store.get_events(session_id) + assert await duckdb_adk_store.get_session("test-app", "user-005", session_id) is None + events_after = await duckdb_adk_store.get_events("test-app", "user-005", session_id) assert len(events_after) == 0 @@ -236,18 +238,19 @@ async def test_create_event(duckdb_adk_store: DuckdbADKStore) -> None: content = {"text": "Test message", "role": "user"} event_record: EventRecord = { + "id": "event-002", + "app_name": "test-app", + "user_id": "user-006", "session_id": session_id, "invocation_id": "", - "author": "user", "timestamp": timestamp, "event_data": {"id": "event-002", "content": content, "app_name": "test-app", "user_id": "user-006"}, } await duckdb_adk_store.append_event(event_record) - events = await duckdb_adk_store.get_events(session_id) + events = await duckdb_adk_store.get_events("test-app", "user-006", session_id) assert len(events) == 1 assert events[0]["session_id"] == session_id - assert events[0]["author"] == "user" # Content is stored inside event_data event_data = ( @@ -262,16 +265,20 @@ async def test_list_events(duckdb_adk_store: DuckdbADKStore) -> None: await duckdb_adk_store.create_session(session_id, "test-app", "user-007", {}) event1: EventRecord = { + "id": "event-1", + "app_name": "test-app", + "user_id": "user-007", "session_id": session_id, "invocation_id": "", - "author": "user", "timestamp": datetime.now(timezone.utc), "event_data": {"id": "event-1", "content": {"message": "First"}, "app_name": "test-app", "user_id": "user-007"}, } event2: EventRecord = { + "id": "event-2", + "app_name": "test-app", + "user_id": "user-007", "session_id": session_id, "invocation_id": "", - "author": "assistant", "timestamp": datetime.now(timezone.utc), "event_data": { "id": "event-2", @@ -283,11 +290,9 @@ async def test_list_events(duckdb_adk_store: DuckdbADKStore) -> None: await duckdb_adk_store.append_event(event1) await duckdb_adk_store.append_event(event2) - events = await duckdb_adk_store.get_events(session_id) + events = await duckdb_adk_store.get_events("test-app", "user-007", session_id) assert len(events) == 2 - assert events[0]["author"] == "user" - assert events[1]["author"] == "assistant" assert events[0]["timestamp"] <= events[1]["timestamp"] @@ -296,7 +301,7 @@ async def test_list_events_empty(duckdb_adk_store: DuckdbADKStore) -> None: session_id = "session-no-events" await duckdb_adk_store.create_session(session_id, "test-app", "user-008", {}) - events = await duckdb_adk_store.get_events(session_id) + events = await duckdb_adk_store.get_events("test-app", "user-008", session_id) assert events == [] @@ -306,9 +311,11 @@ async def test_event_with_optional_fields(duckdb_adk_store: DuckdbADKStore) -> N await duckdb_adk_store.create_session(session_id, "test-app", "user-008", {}) event_record: EventRecord = { + "id": "event-full", + "app_name": "test-app", + "user_id": "user-008", "session_id": session_id, "invocation_id": "inv-123", - "author": "assistant", "timestamp": datetime.now(timezone.utc), "event_data": { "id": "event-full", @@ -325,7 +332,7 @@ async def test_event_with_optional_fields(duckdb_adk_store: DuckdbADKStore) -> N } await duckdb_adk_store.append_event(event_record) - events = await duckdb_adk_store.get_events(session_id) + events = await duckdb_adk_store.get_events("test-app", "user-008", session_id) assert len(events) == 1 # The 5-key record has invocation_id as a top-level indexed column @@ -351,23 +358,29 @@ async def test_event_ordering_by_timestamp(duckdb_adk_store: DuckdbADKStore) -> t3 = datetime.now(timezone.utc) ev_middle: EventRecord = { + "id": "event-middle", + "app_name": "test-app", + "user_id": "user-009", "session_id": session_id, "invocation_id": "", - "author": "", "timestamp": t2, "event_data": {"id": "event-middle", "app_name": "test-app", "user_id": "user-009"}, } ev_last: EventRecord = { + "id": "event-last", + "app_name": "test-app", + "user_id": "user-009", "session_id": session_id, "invocation_id": "", - "author": "", "timestamp": t3, "event_data": {"id": "event-last", "app_name": "test-app", "user_id": "user-009"}, } ev_first: EventRecord = { + "id": "event-first", + "app_name": "test-app", + "user_id": "user-009", "session_id": session_id, "invocation_id": "", - "author": "", "timestamp": t1, "event_data": {"id": "event-first", "app_name": "test-app", "user_id": "user-009"}, } @@ -376,7 +389,7 @@ async def test_event_ordering_by_timestamp(duckdb_adk_store: DuckdbADKStore) -> await duckdb_adk_store.append_event(ev_last) await duckdb_adk_store.append_event(ev_first) - events = await duckdb_adk_store.get_events(session_id) + events = await duckdb_adk_store.get_events("test-app", "user-009", session_id) assert len(events) == 3 # Events should be ordered by timestamp ASC @@ -402,7 +415,7 @@ async def test_session_state_with_complex_data(duckdb_adk_store: DuckdbADKStore) await duckdb_adk_store.create_session(session_id, "test-app", "user-010", complex_state) - session = await duckdb_adk_store.get_session(session_id) + session = await duckdb_adk_store.get_session("test-app", "user-010", session_id) assert session is not None assert session["state"] == complex_state assert session["state"]["user"]["preferences"]["theme"] == "dark" @@ -414,7 +427,7 @@ async def test_empty_state(duckdb_adk_store: DuckdbADKStore) -> None: session_id = "session-empty-state" await duckdb_adk_store.create_session(session_id, "test-app", "user-011", {}) - session = await duckdb_adk_store.get_session(session_id) + session = await duckdb_adk_store.get_session("test-app", "user-011", session_id) assert session is not None assert session["state"] == {} @@ -426,13 +439,13 @@ async def test_table_not_found_handling(tmp_path: Path) -> None: config = DuckDBConfig(connection_config={"database": str(db_path)}) store = DuckdbADKStore(config) - result = await store.get_session("nonexistent") + result = await store.get_session("app", "user", "nonexistent") assert result is None sessions = await store.list_sessions("app", "user") assert sessions == [] - events = await store.get_events("session") + events = await store.get_events("app", "user", "session") assert events == [] finally: if db_path.exists(): @@ -445,15 +458,17 @@ async def test_event_data_round_trip(duckdb_adk_store: DuckdbADKStore) -> None: await duckdb_adk_store.create_session(session_id, "test-app", "user-012", {}) event_record: EventRecord = { + "id": "event-json", + "app_name": "test-app", + "user_id": "user-012", "session_id": session_id, "invocation_id": "", - "author": "system", "timestamp": datetime.now(timezone.utc), "event_data": {"id": "event-json", "content": {"data": "value"}, "app_name": "test-app", "user_id": "user-012"}, } await duckdb_adk_store.append_event(event_record) - events = await duckdb_adk_store.get_events(session_id) + events = await duckdb_adk_store.get_events("test-app", "user-012", session_id) assert len(events) == 1 event_data = ( json.loads(events[0]["event_data"]) if isinstance(events[0]["event_data"], str) else events[0]["event_data"] @@ -467,12 +482,14 @@ async def test_concurrent_session_updates(duckdb_adk_store: DuckdbADKStore) -> N await duckdb_adk_store.create_session(session_id, "test-app", "user-013", {"counter": 0}) for i in range(10): - session = await duckdb_adk_store.get_session(session_id) + session = await duckdb_adk_store.get_session("test-app", "user-013", session_id) assert session is not None current_counter = session["state"]["counter"] - await duckdb_adk_store.update_session_state(session_id, {"counter": current_counter + 1}) + await duckdb_adk_store.update_session_state( + "test-app", "user-013", session_id, {"counter": current_counter + 1} + ) - final_session = await duckdb_adk_store.get_session(session_id) + final_session = await duckdb_adk_store.get_session("test-app", "user-013", session_id) assert final_session is not None assert final_session["state"]["counter"] == 10 @@ -638,7 +655,7 @@ async def test_owner_id_column_without_value(tmp_path: Path) -> None: assert session["id"] == "session-no-fk" - retrieved = await store.get_session("session-no-fk") + retrieved = await store.get_session("test-app", "user-001", "session-no-fk") assert retrieved is not None finally: if db_path.exists(): diff --git a/tests/integration/adapters/sqlite/extensions/adk/test_owner_id_column.py b/tests/integration/adapters/sqlite/extensions/adk/test_owner_id_column.py index 0d69a5f61..6d6c67766 100644 --- a/tests/integration/adapters/sqlite/extensions/adk/test_owner_id_column.py +++ b/tests/integration/adapters/sqlite/extensions/adk/test_owner_id_column.py @@ -126,7 +126,7 @@ async def test_owner_id_column_integer_reference( assert isinstance(session["create_time"], datetime) assert isinstance(session["update_time"], datetime) - retrieved = await store.get_session(session_id) + retrieved = await store.get_session(app_name, user_id, session_id) assert retrieved is not None assert retrieved["id"] == session_id assert retrieved["state"] == initial_state @@ -152,7 +152,7 @@ async def test_owner_id_column_text_reference( assert session["id"] == session_id assert session["state"] == initial_state - retrieved = await store.get_session(session_id) + retrieved = await store.get_session(app_name, user_id, session_id) assert retrieved is not None assert retrieved["id"] == session_id @@ -175,7 +175,7 @@ async def test_owner_id_column_cascade_delete( await store.create_session(session_id, app_name, user_id, initial_state, owner_id=tenant_id) - retrieved_before = await store.get_session(session_id) + retrieved_before = await store.get_session(app_name, user_id, session_id) assert retrieved_before is not None with sqlite_config.provide_connection() as conn: @@ -183,7 +183,7 @@ async def test_owner_id_column_cascade_delete( conn.execute("DELETE FROM tenants WHERE id = ?", (tenant_id,)) conn.commit() - retrieved_after = await store.get_session(session_id) + retrieved_after = await store.get_session(app_name, user_id, session_id) assert retrieved_after is None @@ -260,7 +260,7 @@ async def test_without_owner_id_column( assert session["id"] == session_id assert session["state"] == initial_state - retrieved = await store.get_session(session_id) + retrieved = await store.get_session(app_name, user_id, session_id) assert retrieved is not None assert retrieved["id"] == session_id @@ -310,8 +310,8 @@ async def test_multi_tenant_isolation( await store.create_session(session1_id, app_name, user_id, initial_state, owner_id=tenant1_id) await store.create_session(session2_id, app_name, user_id, {"data": "tenant2"}, owner_id=tenant2_id) - session1 = await store.get_session(session1_id) - session2 = await store.get_session(session2_id) + session1 = await store.get_session(app_name, user_id, session1_id) + session2 = await store.get_session(app_name, user_id, session2_id) assert session1 is not None assert session2 is not None @@ -323,8 +323,8 @@ async def test_multi_tenant_isolation( conn.execute("DELETE FROM tenants WHERE id = ?", (tenant1_id,)) conn.commit() - session1_after = await store.get_session(session1_id) - session2_after = await store.get_session(session2_id) + session1_after = await store.get_session(app_name, user_id, session1_id) + session2_after = await store.get_session(app_name, user_id, session2_id) assert session1_after is None assert session2_after is not None @@ -382,5 +382,5 @@ async def test_owner_id_with_default_value( session = await store.create_session(session_id, app_name, user_id, initial_state) assert session["id"] == session_id - retrieved = await store.get_session(session_id) + retrieved = await store.get_session(app_name, user_id, session_id) assert retrieved is not None diff --git a/tests/integration/adapters/sqlite/extensions/adk/test_store.py b/tests/integration/adapters/sqlite/extensions/adk/test_store.py index 51e4e1834..663651c8a 100644 --- a/tests/integration/adapters/sqlite/extensions/adk/test_store.py +++ b/tests/integration/adapters/sqlite/extensions/adk/test_store.py @@ -37,7 +37,7 @@ async def test_sqlite_session_empty_state_round_trip(tmp_path: Path) -> None: config, store = await _build_store(tmp_path) try: created = await store.create_session("session-empty", "app", "user", {}) - fetched = await store.get_session("session-empty") + fetched = await store.get_session("app", "user", "session-empty") assert created["state"] == {} assert fetched is not None @@ -144,16 +144,18 @@ async def test_sqlite_append_event_and_update_state_is_atomic_contract(tmp_path: await store.create_session(session_id, "app", "user", {}) event: EventRecord = { + "id": "event-1", + "app_name": "app", + "user_id": "user", "session_id": session_id, "invocation_id": "inv-1", - "author": "user", "timestamp": datetime(2026, 5, 10, 12, 0, tzinfo=timezone.utc), "event_data": {"id": "event-1", "content": {"parts": [{"text": "hello"}]}}, } - await store.append_event_and_update_state(event, session_id, {"turn": 1}) + await store.append_event_and_update_state(event, "app", "user", session_id, {"turn": 1}) - session = await store.get_session(session_id) - events = await store.get_events(session_id) + session = await store.get_session("app", "user", session_id) + events = await store.get_events("app", "user", session_id) assert session is not None assert session["state"] == {"turn": 1} @@ -170,10 +172,10 @@ async def test_sqlite_reads_return_empty_when_tables_missing(tmp_path: Path) -> config = SqliteConfig(connection_config={"database": str(db_path)}) store = SqliteADKStore(config) try: - assert await store.get_session("missing") is None + assert await store.get_session("app", "user", "missing") is None assert await store.list_sessions("app") == [] assert await store.list_sessions("app", "user") == [] - assert await store.get_events("session-x") == [] + assert await store.get_events("app", "user", "session-x") == [] finally: config.close_pool() @@ -188,15 +190,19 @@ async def test_sqlite_get_events_filters_by_timestamp_and_limit(tmp_path: Path) for index in range(3): event: EventRecord = { + "id": f"event-{index}", + "app_name": "app", + "user_id": "user", "session_id": session_id, "invocation_id": f"inv-{index}", - "author": "user", "timestamp": base + timedelta(seconds=index), "event_data": {"id": f"event-{index}"}, } await store.append_event(event) - events = await store.get_events(session_id, after_timestamp=base + timedelta(milliseconds=500), limit=1) + events = await store.get_events( + "app", "user", session_id, after_timestamp=base + timedelta(milliseconds=500), limit=1 + ) assert len(events) == 1 assert events[0]["invocation_id"] == "inv-1" diff --git a/tests/unit/adapters/test_psycopg/test_adk_store.py b/tests/unit/adapters/test_psycopg/test_adk_store.py index 6320fac44..6f5ca672f 100644 --- a/tests/unit/adapters/test_psycopg/test_adk_store.py +++ b/tests/unit/adapters/test_psycopg/test_adk_store.py @@ -72,18 +72,21 @@ def test_sync_append_event_inserts_without_session_update() -> None: """append_event must insert a single event without writing session state.""" store, cursor, connection = _build_store() event_record = { + "id": "event-1", "session_id": "session-1", "invocation_id": "", - "author": "assistant", "timestamp": datetime.now(timezone.utc), "event_data": {"id": "event-1"}, + "app_name": "app-1", + "user_id": "user-1", } store._append_event(event_record) # type: ignore[arg-type] assert len(cursor.execute_calls) == 1 _, params = cursor.execute_calls[0] - assert params[0] == "session-1" + assert params[0] == "event-1" + assert params[1] == "session-1" assert isinstance(params[4], Jsonb) assert connection.commit_called @@ -93,18 +96,20 @@ def test_sync_get_events_passes_after_timestamp_and_limit() -> None: base_time = datetime(2026, 1, 1, tzinfo=timezone.utc) rows = [ { + "id": "event-1", "session_id": "session-1", "invocation_id": "", - "author": "assistant", "timestamp": base_time, "event_data": {"id": "event-2"}, + "app_name": "app-1", + "user_id": "user-1", } ] store, cursor, _ = _build_store(rows) - result = store._get_events("session-1", after_timestamp=base_time, limit=1) + result = store._get_events("app-1", "user-1", "session-1", after_timestamp=base_time, limit=1) assert len(cursor.execute_calls) == 1 _, params = cursor.execute_calls[0] - assert params == ("session-1", base_time, 1) + assert params == ("app-1", "user-1", "session-1", base_time, 1) assert result[0]["event_data"]["id"] == "event-2" diff --git a/tests/unit/adapters/test_spanner/test_adk_store.py b/tests/unit/adapters/test_spanner/test_adk_store.py index 50d604091..8bbb116ee 100644 --- a/tests/unit/adapters/test_spanner/test_adk_store.py +++ b/tests/unit/adapters/test_spanner/test_adk_store.py @@ -19,9 +19,11 @@ def test_insert_event_preserves_event_record_timestamp() -> None: store = SpannerSyncADKStore(_mock_config()) timestamp = datetime(2026, 5, 10, 12, 0, tzinfo=timezone.utc) event: EventRecord = { + "id": "event-1", + "app_name": "app", + "user_id": "u1", "session_id": "session-1", "invocation_id": "inv-1", - "author": "user", "timestamp": timestamp, "event_data": {"id": "event-1"}, } @@ -41,13 +43,14 @@ async def test_append_event_and_update_state_preserves_event_record_timestamp() store = SpannerSyncADKStore(_mock_config()) timestamp = datetime(2026, 5, 10, 12, 0, tzinfo=timezone.utc) event: EventRecord = { + "id": "event-1", + "app_name": "app", + "user_id": "u1", "session_id": "session-1", "invocation_id": "inv-1", - "author": "user", "timestamp": timestamp, "event_data": {"id": "event-1"}, } - # Stub the post-write SELECT — the contract requires returning the refreshed record. fake_record = { "id": "session-1", "app_name": "app", @@ -58,7 +61,7 @@ async def test_append_event_and_update_state_preserves_event_record_timestamp() } with patch.object(store, "_run_write") as run_write, patch.object(store, "_get_session", return_value=fake_record): - returned = await store.append_event_and_update_state(event, "session-1", {"turn": 1}) + returned = await store.append_event_and_update_state(event, "app", "u1", "session-1", {"turn": 1}) event_sql, event_params, _event_types = run_write.call_args.args[0][0] update_sql, _state_params, _state_types = run_write.call_args.args[0][1] diff --git a/tests/unit/extensions/test_adk/test_converters.py b/tests/unit/extensions/test_adk/test_converters.py index b7e87770d..f7d5018a6 100644 --- a/tests/unit/extensions/test_adk/test_converters.py +++ b/tests/unit/extensions/test_adk/test_converters.py @@ -220,40 +220,39 @@ def test_merge_scoped_state_does_not_mutate_session_state() -> None: # --------------------------------------------------------------------------- -def test_event_to_record_only_5_keys() -> None: - """EventRecord has exactly session_id, invocation_id, author, timestamp, event_data.""" +def test_event_to_record_keys() -> None: + """EventRecord has exactly the expected keys.""" event = _make_event() - record = event_to_record(event, "session-1") - assert set(record.keys()) == {"session_id", "invocation_id", "author", "timestamp", "event_data"} + record = event_to_record(event, "test-app", "test-user", "session-1") + assert set(record.keys()) == {"id", "app_name", "user_id", "session_id", "invocation_id", "timestamp", "event_data"} -def test_event_to_record_signature_two_args_only() -> None: - """event_to_record raises TypeError if called with extra positional args (old 4-arg signature).""" +def test_event_to_record_signature_four_args() -> None: + """event_to_record raises TypeError if called with wrong number of args.""" event = _make_event() with pytest.raises(TypeError): - event_to_record(event, "session-1", "app-name", "user-id") # type: ignore[call-arg] + event_to_record(event, "session-1") # type: ignore[call-arg] def test_event_to_record_session_id_stored_correctly() -> None: """session_id in the record matches the argument passed.""" event = _make_event(invocation_id="inv-abc", author="model") - record = event_to_record(event, "my-session-id") + record = event_to_record(event, "test-app", "test-user", "my-session-id") assert record["session_id"] == "my-session-id" def test_event_to_record_indexed_fields_match_event() -> None: - """Indexed scalar columns (invocation_id, author, timestamp) match the source event.""" + """Indexed scalar columns (invocation_id, timestamp) match the source event.""" event = _make_event(invocation_id="inv-xyz", author="tool") - record = event_to_record(event, "s1") + record = event_to_record(event, "test-app", "test-user", "s1") assert record["invocation_id"] == "inv-xyz" - assert record["author"] == "tool" assert isinstance(record["timestamp"], datetime) def test_event_to_record_event_data_matches_model_dump() -> None: """event_data in the record equals event.model_dump(exclude_none=True, mode='json').""" event = _make_event(text="hello", state_delta={"key": "val"}, custom_metadata={"foo": "bar"}) - record = event_to_record(event, "s1") + record = event_to_record(event, "test-app", "test-user", "s1") expected_json = event.model_dump(exclude_none=True, mode="json") assert record["event_data"] == expected_json @@ -261,16 +260,15 @@ def test_event_to_record_event_data_matches_model_dump() -> None: def test_event_to_record_event_data_is_dict() -> None: """event_data field is a plain dict (not bytes, not string).""" event = _make_event() - record = event_to_record(event, "s1") + record = event_to_record(event, "test-app", "test-user", "s1") assert isinstance(record["event_data"], dict) def test_event_to_record_actions_in_event_data_is_structured() -> None: """Actions are stored as structured JSON dict in event_data, not as raw bytes.""" event = _make_event(state_delta={"x": "y"}) - record = event_to_record(event, "s1") + record = event_to_record(event, "test-app", "test-user", "s1") event_data = record["event_data"] - # actions should be a dict in the JSON blob if "actions" in event_data: assert isinstance(event_data["actions"], dict) @@ -278,7 +276,7 @@ def test_event_to_record_actions_in_event_data_is_structured() -> None: def test_event_to_record_timestamp_is_datetime() -> None: """timestamp column is a datetime object with timezone.""" event = _make_event() - record = event_to_record(event, "s1") + record = event_to_record(event, "test-app", "test-user", "s1") assert isinstance(record["timestamp"], datetime) assert record["timestamp"].tzinfo is not None @@ -291,7 +289,7 @@ def test_event_to_record_timestamp_is_datetime() -> None: def test_record_to_event_full_roundtrip_basic() -> None: """Event -> record -> Event produces an identical object for basic fields.""" original = _make_event(event_id="evt-rt", invocation_id="inv-rt", author="model") - record = event_to_record(original, "s1") + record = event_to_record(original, "test-app", "test-user", "s1") restored = record_to_event(record) assert restored.id == original.id @@ -302,7 +300,7 @@ def test_record_to_event_full_roundtrip_basic() -> None: def test_record_to_event_roundtrip_preserves_content() -> None: """Content (parts) survives the round-trip.""" original = _make_event(text="hello world", author="model") - record = event_to_record(original, "s1") + record = event_to_record(original, "test-app", "test-user", "s1") restored = record_to_event(record) assert restored.content is not None @@ -313,7 +311,7 @@ def test_record_to_event_roundtrip_preserves_content() -> None: def test_record_to_event_roundtrip_preserves_actions() -> None: """EventActions (state_delta) survives the round-trip.""" original = _make_event(state_delta={"key": "v1", "other": 42}) - record = event_to_record(original, "s1") + record = event_to_record(original, "test-app", "test-user", "s1") restored = record_to_event(record) assert restored.actions is not None @@ -323,7 +321,7 @@ def test_record_to_event_roundtrip_preserves_actions() -> None: def test_record_to_event_roundtrip_preserves_custom_metadata() -> None: """custom_metadata survives the round-trip.""" original = _make_event(custom_metadata={"tag": "v2", "score": 0.9}) - record = event_to_record(original, "s1") + record = event_to_record(original, "test-app", "test-user", "s1") restored = record_to_event(record) assert restored.custom_metadata == {"tag": "v2", "score": 0.9} @@ -332,7 +330,7 @@ def test_record_to_event_roundtrip_preserves_custom_metadata() -> None: def test_record_to_event_roundtrip_preserves_branch() -> None: """branch field survives the round-trip.""" original = _make_event(branch="feature-branch") - record = event_to_record(original, "s1") + record = event_to_record(original, "test-app", "test-user", "s1") restored = record_to_event(record) assert restored.branch == "feature-branch" @@ -341,7 +339,7 @@ def test_record_to_event_roundtrip_preserves_branch() -> None: def test_record_to_event_roundtrip_preserves_partial_flag() -> None: """partial flag survives the round-trip.""" original = _make_event(partial=True) - record = event_to_record(original, "s1") + record = event_to_record(original, "test-app", "test-user", "s1") restored = record_to_event(record) assert restored.partial is True @@ -350,7 +348,7 @@ def test_record_to_event_roundtrip_preserves_partial_flag() -> None: def test_record_to_event_roundtrip_preserves_turn_complete() -> None: """turn_complete flag survives the round-trip.""" original = _make_event(turn_complete=True) - record = event_to_record(original, "s1") + record = event_to_record(original, "test-app", "test-user", "s1") restored = record_to_event(record) assert restored.turn_complete is True @@ -360,18 +358,17 @@ def test_record_to_event_roundtrip_preserves_timestamp() -> None: """timestamp survives the round-trip within float precision.""" fixed_ts = datetime(2024, 6, 1, 10, 30, 0, tzinfo=timezone.utc).timestamp() event = Event(id="ts-evt", invocation_id="inv-1", author="user", actions=EventActions(), timestamp=fixed_ts) - record = event_to_record(event, "s1") + record = event_to_record(event, "test-app", "test-user", "s1") restored = record_to_event(record) - assert abs(restored.timestamp - fixed_ts) < 1.0 # within 1 second + assert abs(restored.timestamp - fixed_ts) < 1.0 def test_record_to_event_ignores_unknown_fields_in_event_data() -> None: """Unknown event_data fields are ignored by the current ADK Event model.""" event = _make_event(event_id="extra-fields-evt", author="tool") - record = event_to_record(event, "s1") + record = event_to_record(event, "test-app", "test-user", "s1") - # Inject hypothetical future ADK field into event_data record["event_data"]["hypothetical_v3_field"] = "some_value" # type: ignore[index] restored = record_to_event(record) @@ -446,7 +443,7 @@ def test_record_to_session_with_events_round_trip() -> None: update_time=datetime.now(timezone.utc), ) event = _make_event(text="hello", author="user") - event_record = event_to_record(event, "s1") + event_record = event_to_record(event, "app", "u1", "s1") session = record_to_session(session_record, [event_record]) diff --git a/tests/unit/extensions/test_adk/test_store_config.py b/tests/unit/extensions/test_adk/test_store_config.py index 18fa217f5..3bf4d016f 100644 --- a/tests/unit/extensions/test_adk/test_store_config.py +++ b/tests/unit/extensions/test_adk/test_store_config.py @@ -45,16 +45,18 @@ async def create_session( update_time=datetime.now(), ) - async def get_session(self, session_id: str, *, renew_for: int | timedelta | None = None) -> SessionRecord | None: + async def get_session( + self, app_name: str, user_id: str, session_id: str, *, renew_for: int | timedelta | None = None + ) -> SessionRecord | None: return None - async def update_session_state(self, session_id: str, state: dict[str, Any]) -> None: + async def update_session_state(self, app_name: str, user_id: str, session_id: str, state: dict[str, Any]) -> None: return None async def list_sessions(self, app_name: str, user_id: str | None = None) -> list[SessionRecord]: return [] - async def delete_session(self, session_id: str) -> None: + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: return None async def append_event(self, event_record: EventRecord) -> None: @@ -63,25 +65,30 @@ async def append_event(self, event_record: EventRecord) -> None: async def append_event_and_update_state( self, event_record: EventRecord, + app_name: str, + user_id: str, session_id: str, state: dict[str, Any], *, - app_name: str | None = None, - user_id: str | None = None, app_state: dict[str, Any] | None = None, user_state: dict[str, Any] | None = None, ) -> SessionRecord: return SessionRecord( id=session_id, - app_name=app_name or "test-app", - user_id=user_id or "test-user", + app_name=app_name, + user_id=user_id, state=state, create_time=datetime.now(), update_time=datetime.now(), ) async def get_events( - self, session_id: str, after_timestamp: datetime | None = None, limit: int | None = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: datetime | None = None, + limit: int | None = None, ) -> list[EventRecord]: return [] diff --git a/uv.lock b/uv.lock index 6e516fd99..d23cb3708 100644 --- a/uv.lock +++ b/uv.lock @@ -1465,7 +1465,7 @@ name = "exceptiongroup" version = "1.3.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.11'" }, + { name = "typing-extensions", marker = "python_full_version < '3.12'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/50/79/66800aadf48771f6b62f7eb014e352e5d06856655206165d775e675a02c9/exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219", size = 30371, upload-time = "2025-11-21T23:01:54.787Z" } wheels = [ @@ -1501,19 +1501,19 @@ wheels = [ [[package]] name = "faker" -version = "40.18.0" +version = "40.19.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "tzdata", marker = "sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/18/06/70886e82d8f1d2b73454f3a7c1b7405300128df22e70d85a828951366932/faker-40.18.0.tar.gz", hash = "sha256:2207575c0e8f90e6ccd6dbef764de875c614d16d3db4eee9712d9a00087f2e70", size = 1968243, upload-time = "2026-05-14T16:43:04.834Z" } +sdist = { url = "https://files.pythonhosted.org/packages/15/01/28c8ddae8caaf82c929655000963d83e3f01265a9af34e823c2ef2eee8ac/faker-40.19.1.tar.gz", hash = "sha256:76fa71fd3bf320db25e5504eb356f9a76b8a95cd6098524d006f446035b6b89d", size = 1969318, upload-time = "2026-05-22T15:57:37.433Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/84/0b/5c0b2d3a4b7a715f1835dd3f963bfbe841a02ae5cad1df8ee0325dfad235/faker-40.18.0-py3-none-any.whl", hash = "sha256:61a6b94b74605ddb090a065deb197a1c585ae7a874c094cf6693671d271e6083", size = 2006355, upload-time = "2026-05-14T16:43:02.489Z" }, + { url = "https://files.pythonhosted.org/packages/49/b4/40a1ec12ec834604f3848143343baf1c67bc9a1096e401907eaa0d25876a/faker-40.19.1-py3-none-any.whl", hash = "sha256:265259b37c013838baaae34940207288170df385d6c5281413fce56a3504d580", size = 2007643, upload-time = "2026-05-22T15:57:35.867Z" }, ] [[package]] name = "fastapi" -version = "0.136.1" +version = "0.136.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "annotated-doc" }, @@ -1522,9 +1522,9 @@ dependencies = [ { name = "typing-extensions" }, { name = "typing-inspection" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/5d/45/c130091c2dfa061bbfe3150f2a5091ef1adf149f2a8d2ae769ecaf6e99a2/fastapi-0.136.1.tar.gz", hash = "sha256:7af665ad7acfa0a3baf8983d393b6b471b9da10ede59c60045f49fbc89a0fa7f", size = 397448, upload-time = "2026-04-23T16:49:44.046Z" } +sdist = { url = "https://files.pythonhosted.org/packages/81/2d/ff8d91d7b564d464629a0fd50a4489c97fcb836ac230bf3a7269232a9b1f/fastapi-0.136.3.tar.gz", hash = "sha256:e487fae93ad408e6f47641ee4dfe389864fd7bec92e547ea8498fc13f43e83ab", size = 396410, upload-time = "2026-05-23T18:53:15.192Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/5a/ff/2e4eca3ade2c22fe1dea7043b8ee9dabe47753349eb1b56a202de8af6349/fastapi-0.136.1-py3-none-any.whl", hash = "sha256:a6e9d7eeada96c93a4d69cb03836b44fa34e2854accb7244a1ece36cd4781c3f", size = 117683, upload-time = "2026-04-23T16:49:42.437Z" }, + { url = "https://files.pythonhosted.org/packages/e0/82/45359b62a067409bd929ae8a56b8ed13e5a8c8a61194b3c236920999ab83/fastapi-0.136.3-py3-none-any.whl", hash = "sha256:3d2a69bdf04b7e9f3afa292c3bc7a98816bbfafa10bc9b45f3f3700d2f761620", size = 117481, upload-time = "2026-05-23T18:53:16.924Z" }, ] [[package]] @@ -1791,7 +1791,7 @@ s3 = [ [[package]] name = "google-adk" -version = "2.0.0" +version = "2.1.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiosqlite" }, @@ -1808,6 +1808,7 @@ dependencies = [ { name = "packaging" }, { name = "pydantic" }, { name = "python-dotenv" }, + { name = "python-multipart" }, { name = "pyyaml" }, { name = "requests" }, { name = "starlette" }, @@ -1818,9 +1819,9 @@ dependencies = [ { name = "watchdog" }, { name = "websockets" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a5/08/e9af9ab3b0df422f9c9c07251840f8be876694852d3ac06dbe4e15ce01a7/google_adk-2.0.0.tar.gz", hash = "sha256:2f53c70b5de8409d427f0955bc89f1ba30a8397dec5aefa0ac7c3ecd1b4018d4", size = 3337944, upload-time = "2026-05-19T16:17:10.664Z" } +sdist = { url = "https://files.pythonhosted.org/packages/02/d2/58823ea0d5ac32143773d377b014123191ce420480c003190e1f86a9667c/google_adk-2.1.0.tar.gz", hash = "sha256:fd1709bf5e70e5aaa7d148c7b788d2cd00bb659ee10505a731bb2ad609d28968", size = 3340742, upload-time = "2026-05-23T00:13:56.793Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/79/f5/596e879aacad5214945ede50a4c8e4a0811979ecff83c91e0df11ef13961/google_adk-2.0.0-py3-none-any.whl", hash = "sha256:6d06d8e3b9119ccd7721505356f8c0a5253dcfdc426dd9a72e0fef1cf9ed4703", size = 3846211, upload-time = "2026-05-19T16:17:08.081Z" }, + { url = "https://files.pythonhosted.org/packages/0f/5b/56a9992b6f9447437f29e3691a487b16dfdea42aedec16ceced9de88b9f8/google_adk-2.1.0-py3-none-any.whl", hash = "sha256:4ec8a0ccdf8af90f9fa505c740cae921b47590aee3d61bca14f7bfa8bc886e4e", size = 3852259, upload-time = "2026-05-23T00:13:59.611Z" }, ] [[package]] @@ -2610,14 +2611,14 @@ wheels = [ [[package]] name = "joserfc" -version = "1.6.5" +version = "1.6.7" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cryptography" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/3b/dc/5f768c2e391e9afabe5d18e3221346deb5fb6338565f1ccc9e7c6d7befdd/joserfc-1.6.5.tar.gz", hash = "sha256:1482a7db78fb4602e44ed89e51b599d052e091288c7c532c5b694e20149dec48", size = 231881, upload-time = "2026-05-06T04:58:13.408Z" } +sdist = { url = "https://files.pythonhosted.org/packages/1b/cb/52e479f20804904f5df20ac4539d292dcecd1287aaa33cba1d1def1d9d8e/joserfc-1.6.7.tar.gz", hash = "sha256:6999fe89457069ecacd8cc797c88a805f83054dd883333fa0409f74b46479fd7", size = 232158, upload-time = "2026-05-23T01:46:44.069Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/54/3b/ad1cb22e75c963b1f07c8a2329bf47227ce7e4361df5eb2fb101b2ce33ef/joserfc-1.6.5-py3-none-any.whl", hash = "sha256:e9878a0f8243fe7b95e11fdda81374ca9f7a689e302751579d3dfdeec559675e", size = 70464, upload-time = "2026-05-06T04:58:11.668Z" }, + { url = "https://files.pythonhosted.org/packages/c5/e4/bcf6718b5662894c6831f46296b73cd4b1a2e90c20b6d437e20c4997388c/joserfc-1.6.7-py3-none-any.whl", hash = "sha256:9e51e4a64840aa1734a058258e80a4480e2ff2d5686e480e7c92c954a92fbe05", size = 70603, upload-time = "2026-05-23T01:46:42.129Z" }, ] [[package]] @@ -4365,30 +4366,30 @@ wheels = [ [[package]] name = "polars" -version = "1.40.1" +version = "1.41.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "polars-runtime-32" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/b3/8c/bc9bc948058348ed43117cecc3007cd608f395915dae8a00974579a5dab1/polars-1.40.1.tar.gz", hash = "sha256:ab2694134b137596b5a59bfd7b4c54ebbc9b59f9403127f18e32d363777552e8", size = 733574, upload-time = "2026-04-22T19:15:55.507Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a2/13/fe30b3e2f9ab54a27d82af04fb2edc51c7342cbaa88815e175769a9f5901/polars-1.41.0.tar.gz", hash = "sha256:7cb5465eb66eb868fde779bf5c41c9f2f244481d72c52133e8ed10ba64372e4f", size = 737530, upload-time = "2026-05-22T20:20:56.209Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ea/91/74fc60d94488685a92ac9d49d7ec55f3e91fe9b77942a6235a5fa7f249c3/polars-1.40.1-py3-none-any.whl", hash = "sha256:c0f861219d1319cdea45c4ce4d30355a47176b8f98dcedf95ea8269f131b8abd", size = 828723, upload-time = "2026-04-22T19:14:25.452Z" }, + { url = "https://files.pythonhosted.org/packages/c3/c8/5807714256c5f3de08593113df17f14f99417a451cb2d91530ad94785003/polars-1.41.0-py3-none-any.whl", hash = "sha256:35dcd24de88a198dc50929924f064ba12a0a0a4a3e77e116689491b4b3ab58ac", size = 832953, upload-time = "2026-05-22T20:19:30.958Z" }, ] [[package]] name = "polars-runtime-32" -version = "1.40.1" +version = "1.41.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/54/ba/26d40f039be9f552b5fd7365a621bdfc0f8e912ef77094ae4693491b0bae/polars_runtime_32-1.40.1.tar.gz", hash = "sha256:37f3065615d1bf90d03b5326222df4c5c1f8a5d33e50470aa588e3465e6eb814", size = 2935843, upload-time = "2026-04-22T19:15:57.26Z" } +sdist = { url = "https://files.pythonhosted.org/packages/1e/8f/30dc715ea1135b4b80397edf33fe7b1bb124850e96e38d9918e2b3d6d0b0/polars_runtime_32-1.41.0.tar.gz", hash = "sha256:37ffbe5414f14bf43bcc8e08a0386c97c692e3fd4e87af74529d7f14b1b2d1cb", size = 2985826, upload-time = "2026-05-22T20:20:57.622Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/7d/46/22c8af5eed68ac2eeb556e0fa3ca8a7b798e984ceff4450888f3b5ac61fd/polars_runtime_32-1.40.1-cp310-abi3-macosx_10_12_x86_64.whl", hash = "sha256:b748ef652270cc49e9e69f99a035e0eb4d5f856d42bcd6ac4d9d80a40142aa1e", size = 52098755, upload-time = "2026-04-22T19:14:28.555Z" }, - { url = "https://files.pythonhosted.org/packages/c6/3e/48599a38009ca60ff82a6f38c8a621ce3c0286aa7397c7d79e741bd9060e/polars_runtime_32-1.40.1-cp310-abi3-macosx_11_0_arm64.whl", hash = "sha256:d249b3743e05986060cec0a7aaa542d020df6c6b876e556023a310efd581f9be", size = 46367542, upload-time = "2026-04-22T19:14:32.433Z" }, - { url = "https://files.pythonhosted.org/packages/43/e9/384bc069367a1a36ee31c13782c178dbd039b2b873b772d4a0fc23a2373d/polars_runtime_32-1.40.1-cp310-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5987b30e7aa1059d069498496e8dda35afd592b0ac3d46ed87e3ff8df1ad652c", size = 50252104, upload-time = "2026-04-22T19:14:35.945Z" }, - { url = "https://files.pythonhosted.org/packages/15/ef/7d57ceb0651af74194e97ed6583e148d352f03d696090221b8059cdfc90b/polars_runtime_32-1.40.1-cp310-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d7f42a8b3f16fc66002cc0f6516f7dd7653396886ae0ed362ab95c0b3408b59", size = 56250788, upload-time = "2026-04-22T19:14:39.743Z" }, - { url = "https://files.pythonhosted.org/packages/10/0f/e4b3ffc748827a14a474ec9c42e45c066050e440fec57e914091d9adda75/polars_runtime_32-1.40.1-cp310-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:e5f7becc237a7ec9d9a10878dc8e54b73bbf4e2d94a2991c37d7a0b38590d8f9", size = 50432590, upload-time = "2026-04-22T19:14:43.388Z" }, - { url = "https://files.pythonhosted.org/packages/d9/0b/b8d95fbed869fa4caabe9c400e4210374913b376e925e96fdcfa9be6416b/polars_runtime_32-1.40.1-cp310-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:992d14cf191dde043d36fbdbc98a65e43fbc7e9a5024cecd45f838ac4988c1ee", size = 54155564, upload-time = "2026-04-22T19:14:47.239Z" }, - { url = "https://files.pythonhosted.org/packages/06/d9/d091d8fb5cbed5e9536adfed955c4c89987a4cc3b8e73ae4532402b91c74/polars_runtime_32-1.40.1-cp310-abi3-win_amd64.whl", hash = "sha256:f78bb2abd00101cbb23cc0cb068f7e36e081057a15d2ec2dde3dda280709f030", size = 51829755, upload-time = "2026-04-22T19:14:50.85Z" }, - { url = "https://files.pythonhosted.org/packages/65/ad/b33c3022a394f3eb55c3310597cec615412a8a33880055eee191d154a628/polars_runtime_32-1.40.1-cp310-abi3-win_arm64.whl", hash = "sha256:b5cbfaf6b085b420b4bfcbe24e8f665076d1cccfdb80c0484c02a023ce205537", size = 45822104, upload-time = "2026-04-22T19:14:54.192Z" }, + { url = "https://files.pythonhosted.org/packages/5f/e7/9d1630d666eca6a67e2096c0ed2c0e18f1355fe440043fd0830de1b71ab6/polars_runtime_32-1.41.0-cp310-abi3-macosx_10_12_x86_64.whl", hash = "sha256:766b60c74550382731b604ed62a385a8403b341bf18282d3fd2f746fa3c4cafd", size = 52163350, upload-time = "2026-05-22T20:19:34.431Z" }, + { url = "https://files.pythonhosted.org/packages/11/b3/01538d51cd2790729ae13c23db44bc787bdfe20867faeb1087afc390c53b/polars_runtime_32-1.41.0-cp310-abi3-macosx_11_0_arm64.whl", hash = "sha256:bee0d294daca79cedd5749e1bf3373c2d4107eb849fe544a60df6c08abc972ce", size = 46474331, upload-time = "2026-05-22T20:19:37.936Z" }, + { url = "https://files.pythonhosted.org/packages/79/29/efa82e1b3e6711f254df3793f3d3fd99f26ef1bcaffa6533266fa6522de4/polars_runtime_32-1.41.0-cp310-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:08b3f915468bf00d327b4a1236935a4ec3174dcb163785fcd98185ad1319a503", size = 50358997, upload-time = "2026-05-22T20:19:42.209Z" }, + { url = "https://files.pythonhosted.org/packages/53/18/5a04b06b773047cbf43912ab802eef3c1d50ef7e66d51a41b16726f9bc62/polars_runtime_32-1.41.0-cp310-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e72269f768c57229190dba0cb5abb8a1b228e96cc6331273a77a7957576885bd", size = 56332032, upload-time = "2026-05-22T20:19:45.928Z" }, + { url = "https://files.pythonhosted.org/packages/2b/d5/d728ce7a39ea925555db7d2c9f7b5df3ca17568e483db6501783f653e0b9/polars_runtime_32-1.41.0-cp310-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:1532c7560a5c0fd06943080ee42aade721f186becdfbb1baafa622e3199c3b62", size = 50529017, upload-time = "2026-05-22T20:19:49.489Z" }, + { url = "https://files.pythonhosted.org/packages/0b/92/6b2092dfed4278f636499217767204fce19ed72695690e2c4eba99e2892e/polars_runtime_32-1.41.0-cp310-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:43368d8a754ee274f1e1099a7c09aae59cd69d22b8a6f83c0f782a00cd3a6662", size = 54244707, upload-time = "2026-05-22T20:19:52.852Z" }, + { url = "https://files.pythonhosted.org/packages/4b/6e/46b43be0f5becbef65843f438de3950ac6f8d0fa0008d7de0025eed00097/polars_runtime_32-1.41.0-cp310-abi3-win_amd64.whl", hash = "sha256:a9bd6095ecadc6799d166b9e8f7183a7ca8ba0a5aef8a426ec41df8ed8b09df7", size = 51918379, upload-time = "2026-05-22T20:19:55.832Z" }, + { url = "https://files.pythonhosted.org/packages/a3/da/30f15f0c3959b70e7a6583eccd37140a30b5c643ca374792d150b3a357df/polars_runtime_32-1.41.0-cp310-abi3-win_arm64.whl", hash = "sha256:ed922400f0eb393345fd7b6874b150eb943af2b816297a3dde03735cb5f3de08", size = 45921961, upload-time = "2026-05-22T20:19:59.525Z" }, ] [[package]] @@ -5294,16 +5295,16 @@ wheels = [ [[package]] name = "pytest-databases" -version = "0.18.0" +version = "0.19.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "docker" }, { name = "filelock" }, { name = "pytest" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f7/ce/e0458cdb8d84b14156392b4bbc5575b6145b4ae6c3962a500e3d66002da3/pytest_databases-0.18.0.tar.gz", hash = "sha256:d49fa4e85494ec33dd6224affada1ddfdb83736b28f5ae40377220ad5dbbb658", size = 324907, upload-time = "2026-05-12T13:43:52.438Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ab/4e/e899556b3876eec2db9bd630ad3054ced94a9541c26319bec5c4cd00579d/pytest_databases-0.19.0.tar.gz", hash = "sha256:ba7b8e51b551455daf3bd144384f6d4fba23d747b001f071795b02e6be2a3cbf", size = 350034, upload-time = "2026-05-23T18:33:49.262Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d2/ae/129e2529248eceb422e9051f408f752a78c80628700ffde94cc9edbf2d9d/pytest_databases-0.18.0-py3-none-any.whl", hash = "sha256:e2114c9e36ec7f4ff118453e9511e5cffb85294214838a70a8dee36c40340bd1", size = 40309, upload-time = "2026-05-12T13:43:51.033Z" }, + { url = "https://files.pythonhosted.org/packages/de/f4/d6e2edae47a0421a8531a9805fb787b993380222f00dc25cee1634f55f3c/pytest_databases-0.19.0-py3-none-any.whl", hash = "sha256:d14651e23a716ed6f2317bf9ac317c9e7891db701253abf78e9cef1049e8f26b", size = 41525, upload-time = "2026-05-23T18:33:47.983Z" }, ] [package.optional-dependencies] @@ -5404,6 +5405,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0b/d7/1959b9648791274998a9c3526f6d0ec8fd2233e4d4acce81bbae76b44b2a/python_dotenv-1.2.2-py3-none-any.whl", hash = "sha256:1d8214789a24de455a8b8bd8ae6fe3c6b69a5e3d64aa8a8e5d68e694bbcb285a", size = 22101, upload-time = "2026-03-01T16:00:25.09Z" }, ] +[[package]] +name = "python-multipart" +version = "0.0.29" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4e/fe/70bd71a6738b09a0bdf6480ca6436b167469ca4578b2a0efbe390b4b0e70/python_multipart-0.0.29.tar.gz", hash = "sha256:643e93849196645e2dbdd81a0f8829a23123ad7f797a84a364c6fb3563f18904", size = 45678, upload-time = "2026-05-17T17:29:47.654Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8f/cb/769cfc37177252872a45a71f3fbdde9d51b471a3f3c14bfe95dde3407386/python_multipart-0.0.29-py3-none-any.whl", hash = "sha256:2ddcc971cef266225f54f552d8fa10bcfbb1f14446caec199060daac59ff2d69", size = 29640, upload-time = "2026-05-17T17:29:45.69Z" }, +] + [[package]] name = "pytz" version = "2026.2" @@ -5912,20 +5922,20 @@ wheels = [ [[package]] name = "snowballstemmer" -version = "3.0.1" +version = "3.1.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/75/a7/9810d872919697c9d01295633f5d574fb416d47e535f258272ca1f01f447/snowballstemmer-3.0.1.tar.gz", hash = "sha256:6d5eeeec8e9f84d4d56b847692bacf79bc2c8e90c7f80ca4444ff8b6f2e52895", size = 105575, upload-time = "2025-05-09T16:34:51.843Z" } +sdist = { url = "https://files.pythonhosted.org/packages/63/ee/67eef9600338e245ad7838230969a34c823ddbdbccc5e1fc43cd75b55bc9/snowballstemmer-3.1.0.tar.gz", hash = "sha256:fd9e34526b23340cd23ffea6c9f9760974ecc2c2ac9e1d81401443ccdb2a801f", size = 122523, upload-time = "2026-05-24T19:04:19.691Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c8/78/3565d011c61f5a43488987ee32b6f3f656e7f107ac2782dd57bdd7d91d9a/snowballstemmer-3.0.1-py3-none-any.whl", hash = "sha256:6cd7b3897da8d6c9ffb968a6781fa6532dce9c3618a4b127d920dab764a19064", size = 103274, upload-time = "2025-05-09T16:34:50.371Z" }, + { url = "https://files.pythonhosted.org/packages/49/83/ddbf4533c62dd32667ef1238952abef155f3d3391f5be69a352ad1638a42/snowballstemmer-3.1.0-py3-none-any.whl", hash = "sha256:17e6d1da216aa07db6dad37139ea70cf13c4b2e9a096f6e64a9648fc657d3154", size = 104550, upload-time = "2026-05-24T19:04:18.026Z" }, ] [[package]] name = "soupsieve" -version = "2.8.3" +version = "2.8.4" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/7b/ae/2d9c981590ed9999a0d91755b47fc74f74de286b0f5cee14c9269041e6c4/soupsieve-2.8.3.tar.gz", hash = "sha256:3267f1eeea4251fb42728b6dfb746edc9acaffc4a45b27e19450b676586e8349", size = 118627, upload-time = "2026-01-20T04:27:02.457Z" } +sdist = { url = "https://files.pythonhosted.org/packages/47/2c/0a5f6f8ee0d5589e48c7640213ed5175d52cf540a06725b628cc1a45d6ce/soupsieve-2.8.4.tar.gz", hash = "sha256:e121fd02e975c695e4e9e8774a5ee35d74714b59307868dcc5319ad2d9e3328e", size = 121110, upload-time = "2026-05-24T13:55:57.154Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/46/2c/1462b1d0a634697ae9e55b3cecdcb64788e8b7d63f54d923fcd0bb140aed/soupsieve-2.8.3-py3-none-any.whl", hash = "sha256:ed64f2ba4eebeab06cc4962affce381647455978ffc1e36bb79a545b91f45a95", size = 37016, upload-time = "2026-01-20T04:27:01.012Z" }, + { url = "https://files.pythonhosted.org/packages/5e/f5/0c41cb68dcae6b7de4fac4188a3a9589e21fb31df21ea3a2e888db95e6c9/soupsieve-2.8.4-py3-none-any.whl", hash = "sha256:e7e6b0769c8f51ed59acab6e994b00621096cfb1c640a7509295987388fbaf65", size = 37304, upload-time = "2026-05-24T13:55:55.406Z" }, ] [[package]] @@ -6387,62 +6397,57 @@ wheels = [ [[package]] name = "sqlalchemy" -version = "2.0.49" +version = "2.0.50" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "greenlet", marker = "platform_machine == 'AMD64' or platform_machine == 'WIN32' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'ppc64le' or platform_machine == 'win32' or platform_machine == 'x86_64'" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/09/45/461788f35e0364a8da7bda51a1fe1b09762d0c32f12f63727998d85a873b/sqlalchemy-2.0.49.tar.gz", hash = "sha256:d15950a57a210e36dd4cec1aac22787e2a4d57ba9318233e2ef8b2daf9ff2d5f", size = 9898221, upload-time = "2026-04-03T16:38:11.704Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/96/76/f908955139842c362aa877848f42f9249642d5b69e06cee9eae5111da1bd/sqlalchemy-2.0.49-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:42e8804962f9e6f4be2cbaedc0c3718f08f60a16910fa3d86da5a1e3b1bfe60f", size = 2159321, upload-time = "2026-04-03T16:50:11.8Z" }, - { url = "https://files.pythonhosted.org/packages/24/e2/17ba0b7bfbd8de67196889b6d951de269e8a46057d92baca162889beb16d/sqlalchemy-2.0.49-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:cc992c6ed024c8c3c592c5fc9846a03dd68a425674900c70122c77ea16c5fb0b", size = 3238937, upload-time = "2026-04-03T16:54:45.731Z" }, - { url = "https://files.pythonhosted.org/packages/90/1e/410dd499c039deacff395eec01a9da057125fcd0c97e3badc252c6a2d6a7/sqlalchemy-2.0.49-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6eb188b84269f357669b62cb576b5b918de10fb7c728a005fa0ebb0b758adce1", size = 3237188, upload-time = "2026-04-03T16:56:53.217Z" }, - { url = "https://files.pythonhosted.org/packages/ab/06/e797a8b98a3993ac4bc785309b9b6d005457fc70238ee6cefa7c8867a92e/sqlalchemy-2.0.49-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:62557958002b69699bdb7f5137c6714ca1133f045f97b3903964f47db97ea339", size = 3190061, upload-time = "2026-04-03T16:54:47.489Z" }, - { url = "https://files.pythonhosted.org/packages/44/d3/5a9f7ef580af1031184b38235da6ac58c3b571df01c9ec061c44b2b0c5a6/sqlalchemy-2.0.49-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:da9b91bca419dc9b9267ffadde24eae9b1a6bffcd09d0a207e5e3af99a03ce0d", size = 3211477, upload-time = "2026-04-03T16:56:55.056Z" }, - { url = "https://files.pythonhosted.org/packages/69/ec/7be8c8cb35f038e963a203e4fe5a028989167cc7299927b7cf297c271e37/sqlalchemy-2.0.49-cp310-cp310-win32.whl", hash = "sha256:5e61abbec255be7b122aa461021daa7c3f310f3e743411a67079f9b3cc91ece3", size = 2119965, upload-time = "2026-04-03T17:00:50.009Z" }, - { url = "https://files.pythonhosted.org/packages/b5/31/0defb93e3a10b0cf7d1271aedd87251a08c3a597ee4f353281769b547b5a/sqlalchemy-2.0.49-cp310-cp310-win_amd64.whl", hash = "sha256:0c98c59075b890df8abfcc6ad632879540f5791c68baebacb4f833713b510e75", size = 2142935, upload-time = "2026-04-03T17:00:51.675Z" }, - { url = "https://files.pythonhosted.org/packages/60/b5/e3617cc67420f8f403efebd7b043128f94775e57e5b84e7255203390ceae/sqlalchemy-2.0.49-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c5070135e1b7409c4161133aa525419b0062088ed77c92b1da95366ec5cbebbe", size = 2159126, upload-time = "2026-04-03T16:50:13.242Z" }, - { url = "https://files.pythonhosted.org/packages/20/9b/91ca80403b17cd389622a642699e5f6564096b698e7cdcbcbb6409898bc4/sqlalchemy-2.0.49-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9ac7a3e245fd0310fd31495eb61af772e637bdf7d88ee81e7f10a3f271bff014", size = 3315509, upload-time = "2026-04-03T16:54:49.332Z" }, - { url = "https://files.pythonhosted.org/packages/b1/61/0722511d98c54de95acb327824cb759e8653789af2b1944ab1cc69d32565/sqlalchemy-2.0.49-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4d4e5a0ceba319942fa6b585cf82539288a61e314ef006c1209f734551ab9536", size = 3315014, upload-time = "2026-04-03T16:56:56.376Z" }, - { url = "https://files.pythonhosted.org/packages/46/55/d514a653ffeb4cebf4b54c47bec32ee28ad89d39fafba16eeed1d81dccd5/sqlalchemy-2.0.49-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:3ddcb27fb39171de36e207600116ac9dfd4ae46f86c82a9bf3934043e80ebb88", size = 3267388, upload-time = "2026-04-03T16:54:51.272Z" }, - { url = "https://files.pythonhosted.org/packages/2f/16/0dcc56cb6d3335c1671a2258f5d2cb8267c9a2260e27fde53cbfb1b3540a/sqlalchemy-2.0.49-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:32fe6a41ad97302db2931f05bb91abbcc65b5ce4c675cd44b972428dd2947700", size = 3289602, upload-time = "2026-04-03T16:56:57.63Z" }, - { url = "https://files.pythonhosted.org/packages/51/6c/f8ab6fb04470a133cd80608db40aa292e6bae5f162c3a3d4ab19544a67af/sqlalchemy-2.0.49-cp311-cp311-win32.whl", hash = "sha256:46d51518d53edfbe0563662c96954dc8fcace9832332b914375f45a99b77cc9a", size = 2119044, upload-time = "2026-04-03T17:00:53.455Z" }, - { url = "https://files.pythonhosted.org/packages/c4/59/55a6d627d04b6ebb290693681d7683c7da001eddf90b60cfcc41ee907978/sqlalchemy-2.0.49-cp311-cp311-win_amd64.whl", hash = "sha256:951d4a210744813be63019f3df343bf233b7432aadf0db54c75802247330d3af", size = 2143642, upload-time = "2026-04-03T17:00:54.769Z" }, - { url = "https://files.pythonhosted.org/packages/49/b3/2de412451330756aaaa72d27131db6dde23995efe62c941184e15242a5fa/sqlalchemy-2.0.49-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4bbccb45260e4ff1b7db0be80a9025bb1e6698bdb808b83fff0000f7a90b2c0b", size = 2157681, upload-time = "2026-04-03T16:53:07.132Z" }, - { url = "https://files.pythonhosted.org/packages/50/84/b2a56e2105bd11ebf9f0b93abddd748e1a78d592819099359aa98134a8bf/sqlalchemy-2.0.49-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:fb37f15714ec2652d574f021d479e78cd4eb9d04396dca36568fdfffb3487982", size = 3338976, upload-time = "2026-04-03T17:07:40Z" }, - { url = "https://files.pythonhosted.org/packages/2c/fa/65fcae2ed62f84ab72cf89536c7c3217a156e71a2c111b1305ab6f0690e2/sqlalchemy-2.0.49-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3bb9ec6436a820a4c006aad1ac351f12de2f2dbdaad171692ee457a02429b672", size = 3351937, upload-time = "2026-04-03T17:12:23.374Z" }, - { url = "https://files.pythonhosted.org/packages/f8/2f/6fd118563572a7fe475925742eb6b3443b2250e346a0cc27d8d408e73773/sqlalchemy-2.0.49-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8d6efc136f44a7e8bc8088507eaabbb8c2b55b3dbb63fe102c690da0ddebe55e", size = 3281646, upload-time = "2026-04-03T17:07:41.949Z" }, - { url = "https://files.pythonhosted.org/packages/c5/d7/410f4a007c65275b9cf82354adb4bb8ba587b176d0a6ee99caa16fe638f8/sqlalchemy-2.0.49-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e06e617e3d4fd9e51d385dfe45b077a41e9d1b033a7702551e3278ac597dc750", size = 3316695, upload-time = "2026-04-03T17:12:25.642Z" }, - { url = "https://files.pythonhosted.org/packages/d9/95/81f594aa60ded13273a844539041ccf1e66c5a7bed0a8e27810a3b52d522/sqlalchemy-2.0.49-cp312-cp312-win32.whl", hash = "sha256:83101a6930332b87653886c01d1ee7e294b1fe46a07dd9a2d2b4f91bcc88eec0", size = 2117483, upload-time = "2026-04-03T17:05:40.896Z" }, - { url = "https://files.pythonhosted.org/packages/47/9e/fd90114059175cac64e4fafa9bf3ac20584384d66de40793ae2e2f26f3bb/sqlalchemy-2.0.49-cp312-cp312-win_amd64.whl", hash = "sha256:618a308215b6cececb6240b9abde545e3acdabac7ae3e1d4e666896bf5ba44b4", size = 2144494, upload-time = "2026-04-03T17:05:42.282Z" }, - { url = "https://files.pythonhosted.org/packages/ae/81/81755f50eb2478eaf2049728491d4ea4f416c1eb013338682173259efa09/sqlalchemy-2.0.49-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:df2d441bacf97022e81ad047e1597552eb3f83ca8a8f1a1fdd43cd7fe3898120", size = 2154547, upload-time = "2026-04-03T16:53:08.64Z" }, - { url = "https://files.pythonhosted.org/packages/a2/bc/3494270da80811d08bcfa247404292428c4fe16294932bce5593f215cad9/sqlalchemy-2.0.49-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8e20e511dc15265fb433571391ba313e10dd8ea7e509d51686a51313b4ac01a2", size = 3280782, upload-time = "2026-04-03T17:07:43.508Z" }, - { url = "https://files.pythonhosted.org/packages/cd/f5/038741f5e747a5f6ea3e72487211579d8cbea5eb9827a9cbd61d0108c4bd/sqlalchemy-2.0.49-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:47604cb2159f8bbd5a1ab48a714557156320f20871ee64d550d8bf2683d980d3", size = 3297156, upload-time = "2026-04-03T17:12:27.697Z" }, - { url = "https://files.pythonhosted.org/packages/88/50/a6af0ff9dc954b43a65ca9b5367334e45d99684c90a3d3413fc19a02d43c/sqlalchemy-2.0.49-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:22d8798819f86720bc646ab015baff5ea4c971d68121cb36e2ebc2ee43ead2b7", size = 3228832, upload-time = "2026-04-03T17:07:45.38Z" }, - { url = "https://files.pythonhosted.org/packages/bc/d1/5f6bdad8de0bf546fc74370939621396515e0cdb9067402d6ba1b8afbe9a/sqlalchemy-2.0.49-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9b1c058c171b739e7c330760044803099c7fff11511e3ab3573e5327116a9c33", size = 3267000, upload-time = "2026-04-03T17:12:29.657Z" }, - { url = "https://files.pythonhosted.org/packages/f7/30/ad62227b4a9819a5e1c6abff77c0f614fa7c9326e5a3bdbee90f7139382b/sqlalchemy-2.0.49-cp313-cp313-win32.whl", hash = "sha256:a143af2ea6672f2af3f44ed8f9cd020e9cc34c56f0e8db12019d5d9ecf41cb3b", size = 2115641, upload-time = "2026-04-03T17:05:43.989Z" }, - { url = "https://files.pythonhosted.org/packages/17/3a/7215b1b7d6d49dc9a87211be44562077f5f04f9bb5a59552c1c8e2d98173/sqlalchemy-2.0.49-cp313-cp313-win_amd64.whl", hash = "sha256:12b04d1db2663b421fe072d638a138460a51d5a862403295671c4f3987fb9148", size = 2141498, upload-time = "2026-04-03T17:05:45.7Z" }, - { url = "https://files.pythonhosted.org/packages/28/4b/52a0cb2687a9cd1648252bb257be5a1ba2c2ded20ba695c65756a55a15a4/sqlalchemy-2.0.49-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:24bd94bb301ec672d8f0623eba9226cc90d775d25a0c92b5f8e4965d7f3a1518", size = 3560807, upload-time = "2026-04-03T16:58:31.666Z" }, - { url = "https://files.pythonhosted.org/packages/8c/d8/fda95459204877eed0458550d6c7c64c98cc50c2d8d618026737de9ed41a/sqlalchemy-2.0.49-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a51d3db74ba489266ef55c7a4534eb0b8db9a326553df481c11e5d7660c8364d", size = 3527481, upload-time = "2026-04-03T17:06:00.155Z" }, - { url = "https://files.pythonhosted.org/packages/ff/0a/2aac8b78ac6487240cf7afef8f203ca783e8796002dc0cf65c4ee99ff8bb/sqlalchemy-2.0.49-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:55250fe61d6ebfd6934a272ee16ef1244e0f16b7af6cd18ab5b1fc9f08631db0", size = 3468565, upload-time = "2026-04-03T16:58:33.414Z" }, - { url = "https://files.pythonhosted.org/packages/a5/3d/ce71cfa82c50a373fd2148b3c870be05027155ce791dc9a5dcf439790b8b/sqlalchemy-2.0.49-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:46796877b47034b559a593d7e4b549aba151dae73f9e78212a3478161c12ab08", size = 3477769, upload-time = "2026-04-03T17:06:02.787Z" }, - { url = "https://files.pythonhosted.org/packages/d5/e8/0a9f5c1f7c6f9ca480319bf57c2d7423f08d31445974167a27d14483c948/sqlalchemy-2.0.49-cp313-cp313t-win32.whl", hash = "sha256:9c4969a86e41454f2858256c39bdfb966a20961e9b58bf8749b65abf447e9a8d", size = 2143319, upload-time = "2026-04-03T17:02:04.328Z" }, - { url = "https://files.pythonhosted.org/packages/0e/51/fb5240729fbec73006e137c4f7a7918ffd583ab08921e6ff81a999d6517a/sqlalchemy-2.0.49-cp313-cp313t-win_amd64.whl", hash = "sha256:b9870d15ef00e4d0559ae10ee5bc71b654d1f20076dbe8bc7ed19b4c0625ceba", size = 2175104, upload-time = "2026-04-03T17:02:05.989Z" }, - { url = "https://files.pythonhosted.org/packages/55/33/bf28f618c0a9597d14e0b9ee7d1e0622faff738d44fe986ee287cdf1b8d0/sqlalchemy-2.0.49-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:233088b4b99ebcbc5258c755a097aa52fbf90727a03a5a80781c4b9c54347a2e", size = 2156356, upload-time = "2026-04-03T16:53:09.914Z" }, - { url = "https://files.pythonhosted.org/packages/d1/a7/5f476227576cb8644650eff68cc35fa837d3802b997465c96b8340ced1e2/sqlalchemy-2.0.49-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:57ca426a48eb2c682dae8204cd89ea8ab7031e2675120a47924fabc7caacbc2a", size = 3276486, upload-time = "2026-04-03T17:07:46.9Z" }, - { url = "https://files.pythonhosted.org/packages/2e/84/efc7c0bf3a1c5eef81d397f6fddac855becdbb11cb38ff957888603014a7/sqlalchemy-2.0.49-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:685e93e9c8f399b0c96a624799820176312f5ceef958c0f88215af4013d29066", size = 3281479, upload-time = "2026-04-03T17:12:32.226Z" }, - { url = "https://files.pythonhosted.org/packages/91/68/bb406fa4257099c67bd75f3f2261b129c63204b9155de0d450b37f004698/sqlalchemy-2.0.49-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:9e0400fa22f79acc334d9a6b185dc00a44a8e6578aa7e12d0ddcd8434152b187", size = 3226269, upload-time = "2026-04-03T17:07:48.678Z" }, - { url = "https://files.pythonhosted.org/packages/67/84/acb56c00cca9f251f437cb49e718e14f7687505749ea9255d7bd8158a6df/sqlalchemy-2.0.49-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:a05977bffe9bffd2229f477fa75eabe3192b1b05f408961d1bebff8d1cd4d401", size = 3248260, upload-time = "2026-04-03T17:12:34.381Z" }, - { url = "https://files.pythonhosted.org/packages/56/19/6a20ea25606d1efd7bd1862149bb2a22d1451c3f851d23d887969201633f/sqlalchemy-2.0.49-cp314-cp314-win32.whl", hash = "sha256:0f2fa354ba106eafff2c14b0cc51f22801d1e8b2e4149342023bd6f0955de5f5", size = 2118463, upload-time = "2026-04-03T17:05:47.093Z" }, - { url = "https://files.pythonhosted.org/packages/cf/4f/8297e4ed88e80baa1f5aa3c484a0ee29ef3c69c7582f206c916973b75057/sqlalchemy-2.0.49-cp314-cp314-win_amd64.whl", hash = "sha256:77641d299179c37b89cf2343ca9972c88bb6eef0d5fc504a2f86afd15cd5adf5", size = 2144204, upload-time = "2026-04-03T17:05:48.694Z" }, - { url = "https://files.pythonhosted.org/packages/1f/33/95e7216df810c706e0cd3655a778604bbd319ed4f43333127d465a46862d/sqlalchemy-2.0.49-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c1dc3368794d522f43914e03312202523cc89692f5389c32bea0233924f8d977", size = 3565474, upload-time = "2026-04-03T16:58:35.128Z" }, - { url = "https://files.pythonhosted.org/packages/0c/a4/ed7b18d8ccf7f954a83af6bb73866f5bc6f5636f44c7731fbb741f72cc4f/sqlalchemy-2.0.49-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7c821c47ecfe05cc32140dcf8dc6fd5d21971c86dbd56eabfe5ba07a64910c01", size = 3530567, upload-time = "2026-04-03T17:06:04.587Z" }, - { url = "https://files.pythonhosted.org/packages/73/a3/20faa869c7e21a827c4a2a42b41353a54b0f9f5e96df5087629c306df71e/sqlalchemy-2.0.49-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:9c04bff9a5335eb95c6ecf1c117576a0aa560def274876fd156cfe5510fccc61", size = 3474282, upload-time = "2026-04-03T16:58:37.131Z" }, - { url = "https://files.pythonhosted.org/packages/b7/50/276b9a007aa0764304ad467eceb70b04822dc32092492ee5f322d559a4dc/sqlalchemy-2.0.49-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:7f605a456948c35260e7b2a39f8952a26f077fd25653c37740ed186b90aaa68a", size = 3480406, upload-time = "2026-04-03T17:06:07.176Z" }, - { url = "https://files.pythonhosted.org/packages/e5/c3/c80fcdb41905a2df650c2a3e0337198b6848876e63d66fe9188ef9003d24/sqlalchemy-2.0.49-cp314-cp314t-win32.whl", hash = "sha256:6270d717b11c5476b0cbb21eedc8d4dbb7d1a956fd6c15a23e96f197a6193158", size = 2149151, upload-time = "2026-04-03T17:02:07.281Z" }, - { url = "https://files.pythonhosted.org/packages/05/52/9f1a62feab6ed368aff068524ff414f26a6daebc7361861035ae00b05530/sqlalchemy-2.0.49-cp314-cp314t-win_amd64.whl", hash = "sha256:275424295f4256fd301744b8f335cff367825d270f155d522b30c7bf49903ee7", size = 2184178, upload-time = "2026-04-03T17:02:08.623Z" }, - { url = "https://files.pythonhosted.org/packages/e5/30/8519fdde58a7bdf155b714359791ad1dc018b47d60269d5d160d311fdc36/sqlalchemy-2.0.49-py3-none-any.whl", hash = "sha256:ec44cfa7ef1a728e88ad41674de50f6db8cfdb3e2af84af86e0041aaf02d43d0", size = 1942158, upload-time = "2026-04-03T16:53:44.135Z" }, +sdist = { url = "https://files.pythonhosted.org/packages/57/da/6fbf010c8ebb347679d0d100b22fe9ba5e13fd04046c5df7280d2f0bf706/sqlalchemy-2.0.50.tar.gz", hash = "sha256:af5607d11ef90fd6a5c0549fe0045dce1663d427426bcfb506dcb5346a85a3b9", size = 9907424, upload-time = "2026-05-24T19:20:04.018Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/70/a9/812a775bd8c1af0966d660238d005baf25e9bced1f038c8e71f00aa637a7/sqlalchemy-2.0.50-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:7af6eeb84985bf840ba779018ff9424d61ff69b52e66b8789d3c8da7bf5341b2", size = 2161617, upload-time = "2026-05-24T20:00:00.761Z" }, + { url = "https://files.pythonhosted.org/packages/d5/74/5a6bc5496e9be8f740fbf80f9e6bd4ab965c8a80870eb07ab015e360957a/sqlalchemy-2.0.50-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0fe7822866f3a9fc5f3db21a290ce8961a53050115f05edf9402b6a5feb92a9f", size = 3244104, upload-time = "2026-05-24T20:07:38.158Z" }, + { url = "https://files.pythonhosted.org/packages/81/55/b260d8df2adc9bb0bf294f67b5f802ff0d84d99442b536b9efd0ea72d447/sqlalchemy-2.0.50-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e8e1b0f6a4dcd9b4839e2320afb5df37a6981cbc20ff9c423ae11c5537bdbd21", size = 3243039, upload-time = "2026-05-24T20:14:23.765Z" }, + { url = "https://files.pythonhosted.org/packages/e5/6d/58714005cbf370f16c3f30d30324a43be10069efcfe764f7236a2e851947/sqlalchemy-2.0.50-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:e195687f1af431c9515416288373b323b6eb599f774409814e89e9d603a56e39", size = 3195017, upload-time = "2026-05-24T20:07:40.086Z" }, + { url = "https://files.pythonhosted.org/packages/30/e8/67527fee039bd3e1a6ce3f03d2b62fd87ab9099c17052810d79496727b66/sqlalchemy-2.0.50-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:ea1a8a2db4b2217d456c8d7a873bfc605f06fe3584d315264ea18c2a17585d0b", size = 3215308, upload-time = "2026-05-24T20:14:26.034Z" }, + { url = "https://files.pythonhosted.org/packages/94/b2/dd3155a6a6706cb89adecf5ee6e0512f7b0ee5cf3e6f4cde67d3c20ebfda/sqlalchemy-2.0.50-cp310-cp310-win32.whl", hash = "sha256:68b154b08088b4ec32bb4d2958bfbb50e57549f91a4cd3e7f928e3553ed69031", size = 2121637, upload-time = "2026-05-24T20:08:06.401Z" }, + { url = "https://files.pythonhosted.org/packages/93/a1/a09c463ee3e7764b5ce5bd19a7f0b6eefbde62e637439ab58498cdbd6b47/sqlalchemy-2.0.50-cp310-cp310-win_amd64.whl", hash = "sha256:66e374271ecb7101273f57af1a62446a953d327eec4f8089147de57c591bbacc", size = 2144673, upload-time = "2026-05-24T20:08:07.936Z" }, + { url = "https://files.pythonhosted.org/packages/b6/5d/3172686af1770e4de2805f919a51441085f589ddadf3dd76ec582f84f497/sqlalchemy-2.0.50-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1aa6e403663a9c43c8fef7ce4bdb4cf48bcd8d352e91deda2a99f963270bd508", size = 2161366, upload-time = "2026-05-24T20:00:02.061Z" }, + { url = "https://files.pythonhosted.org/packages/0f/90/e98dedea3c3e663a17afcd003a34ba45efdac2cea3b6f2e4585e2b1e2537/sqlalchemy-2.0.50-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:51b637a84f9fa35ae1f9017e786cb142974a25305085e1b378b3647a67f65ad3", size = 3318926, upload-time = "2026-05-24T20:07:42.369Z" }, + { url = "https://files.pythonhosted.org/packages/3b/4f/501308c2babb62c11753ecb4ee88ba9eef019419a4d6cbf7cb13e2bad353/sqlalchemy-2.0.50-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2dab927761d9108550f0cf8e66ff21af56f907a0ce0a689793db615e2b55f62c", size = 3319199, upload-time = "2026-05-24T20:14:28.551Z" }, + { url = "https://files.pythonhosted.org/packages/ac/39/d88996c5e03ed6248c3a788d20f0b8d8b376b9f8a495e4bab9df7c72d2f8/sqlalchemy-2.0.50-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:545eae198d37bcf837a10ede3684e2af32458d6f35c597c35c2de7502dc38fc4", size = 3270301, upload-time = "2026-05-24T20:07:44.917Z" }, + { url = "https://files.pythonhosted.org/packages/42/1b/1ae0e65161b51cc43e5ca75430ef79d80e23b5042d645586c2c342c3b92e/sqlalchemy-2.0.50-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0fec460e18cdbb4c7773531122ce9a27e96c6ca17af3933941d94da475ad2c86", size = 3293465, upload-time = "2026-05-24T20:14:30.501Z" }, + { url = "https://files.pythonhosted.org/packages/83/29/17c0003f2c0dfa6d1b97672475707e3ec5980db09defd7fa20beb6833bbd/sqlalchemy-2.0.50-cp311-cp311-win32.whl", hash = "sha256:e6e814658818fd165e749e3d8490ef16cc7f379a118c37ada8b0589ffbaaac22", size = 2120694, upload-time = "2026-05-24T20:08:09.237Z" }, + { url = "https://files.pythonhosted.org/packages/c9/18/280d00654cc19d1fccf236fa5070f6dd04b84dde6f1b2e637bde0ff340a7/sqlalchemy-2.0.50-cp311-cp311-win_amd64.whl", hash = "sha256:1c5f858fe79c9f5d8fda065c06186356acb7f8df3cd52dbd5ee3f200e4b144f5", size = 2145315, upload-time = "2026-05-24T20:08:10.952Z" }, + { url = "https://files.pythonhosted.org/packages/be/b0/a9d19b43f38f878b1278bca5b00b909f7540d41494396dd2561f9ad0956d/sqlalchemy-2.0.50-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:23ae23d8b9d344d30d0a92f06d45825024a5790f1c1dd4cf452636a50d3e58cb", size = 2159807, upload-time = "2026-05-24T19:27:53.086Z" }, + { url = "https://files.pythonhosted.org/packages/f5/2c/191dd58a248fd2cfd4780fa82c375c505e4ad98c8b522fa69ec492130d77/sqlalchemy-2.0.50-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:47b71b933e7b4ebad407c8fdfd70d2c4f08b78b3238bb30eebdd6eb32ca51b89", size = 3343358, upload-time = "2026-05-24T20:09:29.279Z" }, + { url = "https://files.pythonhosted.org/packages/8a/2b/514fce8a7df81cf5bad7ff7865de7ac0c5776a38cc043475c4703eb7fe8b/sqlalchemy-2.0.50-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:110fdac56ace278949f00de805edacbd6141e382d992f9ba28238b3a0827a600", size = 3357994, upload-time = "2026-05-24T20:17:13.495Z" }, + { url = "https://files.pythonhosted.org/packages/35/a6/a0e283f5494f92b0d77e319ff77e437b1ffe4a051ba67c81d53234825475/sqlalchemy-2.0.50-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0f5e4ac70e9e757f6b3e87c0491ff034442ecd8dfd36d041a50564c322dafc0e", size = 3289399, upload-time = "2026-05-24T20:09:32.239Z" }, + { url = "https://files.pythonhosted.org/packages/b7/96/1b07325ba71752d6a028b77d07bed1483ad545f794e8b1dc89b3ba3b3c68/sqlalchemy-2.0.50-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:724f3dcbe53dd0151e3cb5e7ec4ba4c620bede579caacd16275dc35ce06e8615", size = 3321216, upload-time = "2026-05-24T20:17:15.581Z" }, + { url = "https://files.pythonhosted.org/packages/ed/8e/bad6ed253e8a99edfc99af02f7173ec48a1d3ed1b9b35a1b8bc1700900cc/sqlalchemy-2.0.50-cp312-cp312-win32.whl", hash = "sha256:1208050441471d003b7c8cb4054fb084f185cf35ac3f0ea270803865bca9939a", size = 2119194, upload-time = "2026-05-24T19:50:04.943Z" }, + { url = "https://files.pythonhosted.org/packages/b6/2d/314a6690dda4b9cfc571eab1a63cf6fe6e1470aa3759ccda6aa016ee0f5a/sqlalchemy-2.0.50-cp312-cp312-win_amd64.whl", hash = "sha256:9d1af51558029a156a70986b7df88f042b3d158d7c8d8fb5072912d4b32d89c7", size = 2146186, upload-time = "2026-05-24T19:50:06.74Z" }, + { url = "https://files.pythonhosted.org/packages/0b/c4/c42356b527296e9862f67990efce31ef78b4cf69cd3f80873a528a060320/sqlalchemy-2.0.50-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:06a9210bdc5f4298cff0781087e2ff45683922252dacc452846373a58761f093", size = 2156697, upload-time = "2026-05-24T19:27:54.764Z" }, + { url = "https://files.pythonhosted.org/packages/60/a1/b1a70e3c4365ac7fe9e347f3710f19b562c866fb96d45e3c891588789a7b/sqlalchemy-2.0.50-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8b53784972ade4f8174b9aa661f31a06f8a936d2cfdd602913ff3c6dd40ae873", size = 3284260, upload-time = "2026-05-24T20:09:34.195Z" }, + { url = "https://files.pythonhosted.org/packages/3f/4a/f3ac3caa19f263d57b0a47f8c91bbf56583dc2d3fc63acfbf644abb24fe0/sqlalchemy-2.0.50-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:31648fa14460537e768a7303b078e4344d208e0d23e06867c1f376a227ed82db", size = 3302280, upload-time = "2026-05-24T20:17:17.825Z" }, + { url = "https://files.pythonhosted.org/packages/66/55/ccada3e3d62254587819749a0bc69f41173eb48a6e385d10e66d32a9c88e/sqlalchemy-2.0.50-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:03f4323c980ad0e918cc9e5369b015f759f4e534db5bbaf4dc36832c10d05064", size = 3231580, upload-time = "2026-05-24T20:09:36.406Z" }, + { url = "https://files.pythonhosted.org/packages/05/f6/6809349130a2de0e109e7f00fd7d431da9565b9b2868b32ee684754f672b/sqlalchemy-2.0.50-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:2b9dcc43afef8ac157cd92fce96985d6b8b0cfbd3df4d666f66b4d55a75d202f", size = 3269375, upload-time = "2026-05-24T20:17:20.34Z" }, + { url = "https://files.pythonhosted.org/packages/48/84/278a811ef4e07be9c89dc5cdd7be833268509a66a68c4897cf585e67428f/sqlalchemy-2.0.50-cp313-cp313-win32.whl", hash = "sha256:60922d6599065ddca2c6f376b9aa2f41a6b85a271725e0909490bbc50b1998a5", size = 2117229, upload-time = "2026-05-24T19:50:08.215Z" }, + { url = "https://files.pythonhosted.org/packages/f6/1c/067cc6187ed32d2ec222fe6d2643acc1659a6d0659f8a7cbc5ad3ae83280/sqlalchemy-2.0.50-cp313-cp313-win_amd64.whl", hash = "sha256:287086e67275a212c4582d166a6fb03a65ccc5551d80866270ce0dd9f34eccd3", size = 2143126, upload-time = "2026-05-24T19:50:09.691Z" }, + { url = "https://files.pythonhosted.org/packages/df/32/10ac51b4be7cdecd7e93d069251c86dfbf70b7adbd7c67b48ccea6c49e1c/sqlalchemy-2.0.50-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:c966932507a4d7d0a37314927dbfcd89720e3f37d2a1e3352e7ae7939fa8e8a0", size = 2158519, upload-time = "2026-05-24T19:27:56.472Z" }, + { url = "https://files.pythonhosted.org/packages/5a/76/e703d2f7681d7d66c4c891af3f07c7ccf4c76ad7f18351de035b5eda007a/sqlalchemy-2.0.50-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:faffef4bcc20a1892e65e155293d99d60855bbbc79250ab712819cfd56a8e6bb", size = 3282063, upload-time = "2026-05-24T20:09:38.57Z" }, + { url = "https://files.pythonhosted.org/packages/31/26/ef168b184a25701f9995e8fb7e503fafd7a99c1c77cda1bc1a26ea2ed486/sqlalchemy-2.0.50-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6c206aec519a2e7bd08abbfb33436e325fd22c632d9c21a9047e376ce241646e", size = 3287069, upload-time = "2026-05-24T20:17:21.942Z" }, + { url = "https://files.pythonhosted.org/packages/c2/15/765acc2bc693bccc43ca4a95d5b69750da8aaf6db1b5c616536e087f8920/sqlalchemy-2.0.50-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:bef4ac756363227ef6402a75fee025a4bc690f92328e825868939b3b3a446a6d", size = 3230453, upload-time = "2026-05-24T20:09:40.398Z" }, + { url = "https://files.pythonhosted.org/packages/63/61/08e03c3adbf5db0087a0b6816746fec8f3032fb2f7fc899a9bb9b2a48ce4/sqlalchemy-2.0.50-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:96fbee6b19c19cd1556c8bf9419447cf2ec149ffcab7ab64348c23e54ef8547f", size = 3252413, upload-time = "2026-05-24T20:17:24.067Z" }, + { url = "https://files.pythonhosted.org/packages/03/0c/370a1f2db38436c615e10134c8a37de3688e74084792380695f3f5083860/sqlalchemy-2.0.50-cp314-cp314-win32.whl", hash = "sha256:8f00e3eb43ba30eb1b238ee03a8a62309486d1321eda3328bb611e0340033ad8", size = 2120063, upload-time = "2026-05-24T19:50:11.08Z" }, + { url = "https://files.pythonhosted.org/packages/7f/a0/fe92bb9817863bc13ba093bda931979a26cc2ca69f8e8f26d07add3d7c6f/sqlalchemy-2.0.50-cp314-cp314-win_amd64.whl", hash = "sha256:15708c613cd5005b7dffe1f66ee6a63ee8f5e46799f71c70ebad74178c676a39", size = 2145830, upload-time = "2026-05-24T19:50:12.452Z" }, + { url = "https://files.pythonhosted.org/packages/cc/ff/e5640a98a0b2f491eb8fde10fb6c773621a2e44340de231fafcc9370f4a9/sqlalchemy-2.0.50-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:3699dac4be410e97049a1658e9480da9cde956594aa0f3aebc60b88f21c5ba70", size = 2178435, upload-time = "2026-05-24T19:42:58.889Z" }, + { url = "https://files.pythonhosted.org/packages/b7/85/337116e186f1236375b5fb70c21cfac98e8e8ab0d3a47be838dc47a59e08/sqlalchemy-2.0.50-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f96233858e3df43932ac11589e22520da6e8aeb624b03fedfeebb0e8ea213086", size = 3566059, upload-time = "2026-05-24T20:01:20.848Z" }, + { url = "https://files.pythonhosted.org/packages/96/34/bb0e190e161c3c2c24314a65add57218be14a4a9486886b7f5047c1ff7c8/sqlalchemy-2.0.50-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c4e70c46fad30c3bcc6a4708bc0130a3173e11a5b25f0ea4a9d8911b450f1f52", size = 3535366, upload-time = "2026-05-24T20:03:56.768Z" }, + { url = "https://files.pythonhosted.org/packages/df/5a/a7f759f97e4fd499c5d4e4488c760d5a7fbecf3028b465a04274fcd52384/sqlalchemy-2.0.50-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:1918a3cf564d16d95bca7301005f41ab2ad50b07cd3b9da50d3ed986db148d6a", size = 3474879, upload-time = "2026-05-24T20:01:23.058Z" }, + { url = "https://files.pythonhosted.org/packages/9d/d9/2907ea38eb60687d297bf9c39e5ee58053c87b57fe8a9cae97090cecbf10/sqlalchemy-2.0.50-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:b00098cdbdbd38c7be3d568b0c9c3122b8c0ec62b911b57cd5e6e0254d60a76d", size = 3486117, upload-time = "2026-05-24T20:03:59.052Z" }, + { url = "https://files.pythonhosted.org/packages/f2/e3/5aa06f167559f8c0bdae487e297d23ba548150ab016a3418265d617a4985/sqlalchemy-2.0.50-cp314-cp314t-win32.whl", hash = "sha256:1fbd55a969d7ac44a98e3dec75016074f809fa08f871585ace58dde110d1bf3e", size = 2150823, upload-time = "2026-05-24T20:08:58.644Z" }, + { url = "https://files.pythonhosted.org/packages/65/9b/112fb8f977582d7489d036e409e3723948bcf5320b3ac465f3c481bbe8f9/sqlalchemy-2.0.50-cp314-cp314t-win_amd64.whl", hash = "sha256:c5c3cdb753a9004183e1ccb634b41611654c989e61bc68617ce878e46d6f1e51", size = 2185794, upload-time = "2026-05-24T20:09:00.319Z" }, + { url = "https://files.pythonhosted.org/packages/d0/10/f7220e9b784d295d241c86ed99aeb537f92afcd469a64861f2717e9bb077/sqlalchemy-2.0.50-py3-none-any.whl", hash = "sha256:92064363517a3ff8212b5a93b8c62876579d8dfd1ca5b561335f30152d884fa9", size = 1943861, upload-time = "2026-05-24T19:59:01.119Z" }, ] [package.optional-dependencies] @@ -7556,16 +7561,16 @@ wheels = [ [[package]] name = "uvicorn" -version = "0.47.0" +version = "0.48.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, { name = "h11" }, { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f6/b1/8e7077a8641086aea449e1b5752a570f1b5906c64e0a33cd6d93b63a066b/uvicorn-0.47.0.tar.gz", hash = "sha256:7c9a0ea1a9414106bbab7324609c162d8fa0cdcdcb703060987269d77c7bb533", size = 90582, upload-time = "2026-05-14T18:16:54.455Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e6/bf/f6544ba992ddb9a6077343a576f9844f7f8f06ab819aefd00206e9255f18/uvicorn-0.48.0.tar.gz", hash = "sha256:a5504207195d08c2511bf9125ede5ac4a4b71725d519e758d01dcf0bc2d31c37", size = 91074, upload-time = "2026-05-24T12:08:41.925Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/15/41/ac2dfdbc1f60c7af4f994c7a335cfa7040c01642b605d65f611cecc2a1e4/uvicorn-0.47.0-py3-none-any.whl", hash = "sha256:2c5715bc12d1892d84752049f400cd1c3cb018514967fdfeb97640443a6a9432", size = 71301, upload-time = "2026-05-14T18:16:51.762Z" }, + { url = "https://files.pythonhosted.org/packages/01/be/72532be3da7acc5fdfbccdb95215cd04f995a0886532a5b423f929cda4cc/uvicorn-0.48.0-py3-none-any.whl", hash = "sha256:48097851328b87ec36117d3d575234519eb58c2b22d79666e9bbc6c49a761dad", size = 71410, upload-time = "2026-05-24T12:08:40.258Z" }, ] [[package]] From a259d3f59229c77c517502587fb85bf0f6c851ab Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Tue, 26 May 2026 00:15:46 +0000 Subject: [PATCH 26/29] Refactor ADK tests and configurations for clarity and consistency - Removed unnecessary xfail markers in CockroachDB and Psycopg tests. - Updated table names in SQLite and Spanner tests to reflect new schema. - Changed metadata table name in BigQuery tests for consistency. - Enhanced test assertions for session store configurations and migrations. - Consolidated nested configuration tests into flat structure for ADK. - Removed deprecated lifecycle and capability tests. - Added new migration tests for ADK cutover scenarios. - Improved session service tests to validate state persistence and event handling. --- .../examples/extensions/adk/backend_config.py | 4 +- docs/extensions/adk/backends.rst | 4 +- docs/extensions/adk/migrations.rst | 4 +- docs/extensions/adk/schema.rst | 10 +- sqlspec/adapters/aiosqlite/adk/store.py | 4 +- .../adapters/cockroach_asyncpg/adk/store.py | 2 +- .../adapters/cockroach_psycopg/adk/store.py | 81 +-- sqlspec/adapters/duckdb/adk/store.py | 4 +- sqlspec/adapters/psqlpy/adk/store.py | 22 +- sqlspec/adapters/psycopg/adk/store.py | 85 +-- sqlspec/adapters/sqlite/adk/store.py | 4 +- sqlspec/config.py | 675 ++---------------- sqlspec/extensions/adk/_capabilities.py | 101 --- sqlspec/extensions/adk/_config_utils.py | 111 +-- sqlspec/extensions/adk/_lifecycle.py | 105 --- sqlspec/extensions/adk/_versioning.py | 142 ---- sqlspec/extensions/adk/artifact/__init__.py | 2 +- sqlspec/extensions/adk/artifact/store.py | 2 +- sqlspec/extensions/adk/memory/__init__.py | 2 +- sqlspec/extensions/adk/memory/presets.py | 49 +- sqlspec/extensions/adk/memory/service.py | 2 +- sqlspec/extensions/adk/memory/store.py | 4 +- .../adk/migrations/0001_create_adk_tables.py | 179 +---- .../adk/migrations/0002_reset_adk_tables.py | 118 +++ sqlspec/extensions/adk/service.py | 25 +- sqlspec/extensions/adk/store.py | 4 +- .../adapters/_adk_contract_helpers.py | 2 + .../extensions/adk/test_memory_store.py | 2 +- .../asyncpg/extensions/adk/conftest.py | 2 +- .../adk/test_scoped_state_contract.py | 16 - .../adk/test_scoped_state_contract.py | 10 +- .../adk/test_scoped_state_contract.py | 3 - .../adk/test_scoped_state_contract.py | 10 +- .../extensions/adk/test_memory_store.py | 2 +- tests/unit/adapters/test_bigquery_adk.py | 4 +- .../adapters/test_psycopg/test_adk_store.py | 2 +- .../adapters/test_spanner/test_adk_store.py | 2 +- .../extensions/test_adk/test_capabilities.py | 72 -- .../test_adk/test_config_resolution.py | 59 +- .../extensions/test_adk/test_converters.py | 10 +- .../test_adk/test_embedding_presets.py | 11 +- .../test_adk/test_lifecycle_config.py | 86 --- .../extensions/test_adk/test_migrations.py | 151 ++++ .../unit/extensions/test_adk/test_service.py | 149 +++- .../extensions/test_adk/test_store_config.py | 15 +- .../extensions/test_adk/test_versioning.py | 86 --- 46 files changed, 679 insertions(+), 1760 deletions(-) delete mode 100644 sqlspec/extensions/adk/_capabilities.py delete mode 100644 sqlspec/extensions/adk/_lifecycle.py delete mode 100644 sqlspec/extensions/adk/_versioning.py create mode 100644 sqlspec/extensions/adk/migrations/0002_reset_adk_tables.py delete mode 100644 tests/unit/extensions/test_adk/test_capabilities.py delete mode 100644 tests/unit/extensions/test_adk/test_lifecycle_config.py create mode 100644 tests/unit/extensions/test_adk/test_migrations.py delete mode 100644 tests/unit/extensions/test_adk/test_versioning.py diff --git a/docs/examples/extensions/adk/backend_config.py b/docs/examples/extensions/adk/backend_config.py index 5a5ce6ff4..7984dfdc0 100644 --- a/docs/examples/extensions/adk/backend_config.py +++ b/docs/examples/extensions/adk/backend_config.py @@ -15,8 +15,8 @@ def test_adk_backend_config() -> None: "events_table": "adk_event", "app_state_table": "adk_app_state", "user_state_table": "adk_user_state", - "metadata_table": "adk_internal_metadata", - "memory_table": "adk_memory_entries", + "metadata_table": "adk_metadata", + "memory_table": "adk_memory", "memory_use_fts": True, } diff --git a/docs/extensions/adk/backends.rst b/docs/extensions/adk/backends.rst index 37af698f5..353c232c3 100644 --- a/docs/extensions/adk/backends.rst +++ b/docs/extensions/adk/backends.rst @@ -328,8 +328,8 @@ All backends are configured through ``extension_config["adk"]``: "events_table": "adk_event", "app_state_table": "adk_app_state", "user_state_table": "adk_user_state", - "metadata_table": "adk_internal_metadata", - "memory_table": "adk_memory_entries", + "metadata_table": "adk_metadata", + "memory_table": "adk_memory", "memory_use_fts": True, "owner_id_column": "tenant_id INTEGER NOT NULL", } diff --git a/docs/extensions/adk/migrations.rst b/docs/extensions/adk/migrations.rst index 630fbc6bc..2ede744a8 100644 --- a/docs/extensions/adk/migrations.rst +++ b/docs/extensions/adk/migrations.rst @@ -81,9 +81,9 @@ note the following schema changes: store ``app:`` and ``user:`` scoped keys. Raw ``adk_session.state`` rows now contain only session-scoped keys; ``SQLSpecSessionService.get_session()`` returns the merged ADK view. -- **Internal metadata table**: New ``adk_internal_metadata`` table seeded with +- **Internal metadata table**: New ``adk_metadata`` table seeded with ``schema_version = 1``. -- **Artifact table**: New table (``adk_artifact_versions``) for artifact +- **Artifact table**: New table (``adk_artifact``) for artifact metadata. Create this table when enabling the artifact service. - **BigQuery**: Treated as an analytics-replica backend. Use Spanner or a PostgreSQL-family adapter for latency-sensitive live session state. diff --git a/docs/extensions/adk/schema.rst b/docs/extensions/adk/schema.rst index 31706dd48..ebf5f1f05 100644 --- a/docs/extensions/adk/schema.rst +++ b/docs/extensions/adk/schema.rst @@ -234,7 +234,7 @@ Internal Metadata Table The metadata table stores ADK schema metadata used by migrations and future schema-version dispatch. -Default name: ``adk_internal_metadata`` +Default name: ``adk_metadata`` .. list-table:: :header-rows: 1 @@ -290,7 +290,7 @@ Memory Table The memory table stores long-term context entries that agents can search and reference across sessions. -Default name: ``adk_memory_entries`` +Default name: ``adk_memory`` .. list-table:: :header-rows: 1 @@ -333,7 +333,7 @@ Concrete artifact metadata stores use this table shape to store versioning metadata for binary artifacts. Content bytes are stored separately in object storage; this table tracks ownership, versioning, and canonical URIs. -Default name: ``adk_artifact_versions`` +Default name: ``adk_artifact`` .. list-table:: :header-rows: 1 @@ -386,8 +386,8 @@ All table names are configurable: "events_table": "my_event", # default: "adk_event" "app_state_table": "my_app_state", # default: "adk_app_state" "user_state_table": "my_user_state", # default: "adk_user_state" - "metadata_table": "my_adk_metadata", # default: "adk_internal_metadata" - "memory_table": "my_memory", # default: "adk_memory_entries" + "metadata_table": "my_adk_metadata", # default: "adk_metadata" + "memory_table": "my_memory", # default: "adk_memory" "artifact_table": "my_artifacts", # artifact metadata stores } }, diff --git a/sqlspec/adapters/aiosqlite/adk/store.py b/sqlspec/adapters/aiosqlite/adk/store.py index 7cbd3ebb0..c02e11c9a 100644 --- a/sqlspec/adapters/aiosqlite/adk/store.py +++ b/sqlspec/adapters/aiosqlite/adk/store.py @@ -772,7 +772,7 @@ class AiosqliteADKMemoryStore(BaseAsyncADKMemoryStore["AiosqliteConfig"]): connection_config={"database": ":memory:"}, extension_config={ "adk": { - "memory_table": "adk_memory_entries", + "memory_table": "adk_memory", "memory_use_fts": False, "memory_max_results": 20, } @@ -800,7 +800,7 @@ def __init__(self, config: "AiosqliteConfig") -> None: Notes: Configuration is read from config.extension_config["adk"]: - - memory_table: Memory table name (default: "adk_memory_entries") + - memory_table: Memory table name (default: "adk_memory") - memory_use_fts: Enable full-text search when supported (default: False) - memory_max_results: Max search results (default: 20) - owner_id_column: Optional owner FK column DDL (default: None) diff --git a/sqlspec/adapters/cockroach_asyncpg/adk/store.py b/sqlspec/adapters/cockroach_asyncpg/adk/store.py index a7caba06a..5bf8948ef 100644 --- a/sqlspec/adapters/cockroach_asyncpg/adk/store.py +++ b/sqlspec/adapters/cockroach_asyncpg/adk/store.py @@ -226,11 +226,11 @@ async def append_event_and_update_state( event_record["timestamp"], event_record["event_data"], ) - row = await conn.fetchrow(update_sql, state, app_name, user_id, session_id) if app_state: await conn.execute(app_upsert_sql, app_name, app_state) if user_state: await conn.execute(user_upsert_sql, app_name, user_id, user_state) + row = await conn.fetchrow(update_sql, state, app_name, user_id, session_id) if row is None: msg = f"Session {session_id} not found during append_event_and_update_state." diff --git a/sqlspec/adapters/cockroach_psycopg/adk/store.py b/sqlspec/adapters/cockroach_psycopg/adk/store.py index 7038f51eb..c22c3ace3 100644 --- a/sqlspec/adapters/cockroach_psycopg/adk/store.py +++ b/sqlspec/adapters/cockroach_psycopg/adk/store.py @@ -4,6 +4,7 @@ from psycopg import errors from psycopg import sql as pg_sql +from psycopg.rows import dict_row from psycopg.types.json import Jsonb from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord @@ -109,7 +110,7 @@ async def create_session( """ params = (session_id, app_name, user_id, state_json) - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(sql.encode(), params) await conn.commit() @@ -139,7 +140,7 @@ async def get_session( params = (app_name, user_id, session_id) try: - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(sql.encode(), params) row = await cur.fetchone() @@ -164,14 +165,14 @@ async def update_session_state(self, app_name: str, user_id: str, session_id: st WHERE app_name = %s AND user_id = %s AND id = %s """ - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(sql.encode(), (Jsonb(state), app_name, user_id, session_id)) await conn.commit() async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: sql = f"DELETE FROM {self._session_table} WHERE app_name = %s AND user_id = %s AND id = %s" - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(sql.encode(), (app_name, user_id, session_id)) await conn.commit() @@ -194,7 +195,7 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis params = (app_name, user_id) try: - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(sql.encode(), params) rows = await cur.fetchall() @@ -222,7 +223,7 @@ async def append_event(self, event_record: EventRecord) -> None: event_data_value = event_record["event_data"] jsonb_value = Jsonb(event_data_value) if isinstance(event_data_value, dict) else event_data_value - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute( sql.encode(), ( @@ -275,7 +276,7 @@ async def append_event_and_update_state( event_data_value = event_record["event_data"] jsonb_value = Jsonb(event_data_value) if isinstance(event_data_value, dict) else event_data_value - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute( insert_sql.encode(), ( @@ -335,7 +336,7 @@ async def get_events( params.append(limit) try: - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(sql.encode(), tuple(params)) rows = await cur.fetchall() @@ -358,7 +359,7 @@ async def delete_expired_events(self, before: "datetime") -> int: sql = f"DELETE FROM {self._events_table} WHERE timestamp < %s" try: - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(sql.encode(), (before,)) await conn.commit() return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 @@ -369,7 +370,7 @@ async def delete_idle_sessions(self, updated_before: "datetime") -> int: sql = f"DELETE FROM {self._session_table} WHERE update_time < %s" try: - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(sql.encode(), (updated_before,)) await conn.commit() return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 @@ -380,7 +381,7 @@ async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": sql = f"SELECT state FROM {self._app_state_table} WHERE app_name = %s" try: - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(sql.encode(), (app_name,)) row = await cur.fetchone() return row["state"] if row is not None else None @@ -391,7 +392,7 @@ async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | sql = f"SELECT state FROM {self._user_state_table} WHERE app_name = %s AND user_id = %s" try: - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(sql.encode(), (app_name, user_id)) row = await cur.fetchone() return row["state"] if row is not None else None @@ -404,7 +405,7 @@ async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None VALUES (%s, %s, CURRENT_TIMESTAMP) """ - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(sql.encode(), (app_name, Jsonb(state))) await conn.commit() @@ -414,7 +415,7 @@ async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, VALUES (%s, %s, %s, CURRENT_TIMESTAMP) """ - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(sql.encode(), (app_name, user_id, Jsonb(state))) await conn.commit() @@ -422,7 +423,7 @@ async def get_metadata(self, key: str) -> "str | None": sql = f"SELECT value FROM {self._metadata_table} WHERE key = %s" try: - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(sql.encode(), (key,)) row = await cur.fetchone() return row["value"] if row is not None else None @@ -435,7 +436,7 @@ async def set_metadata(self, key: str, value: str) -> None: VALUES (%s, %s) """ - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(sql.encode(), (key, value)) await conn.commit() @@ -772,7 +773,7 @@ def _create_session( """ params = (session_id, app_name, user_id, state_json) - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(sql.encode(), params) conn.commit() @@ -802,7 +803,7 @@ def _get_session( params = (app_name, user_id, session_id) try: - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(sql.encode(), params) row = cur.fetchone() @@ -827,14 +828,14 @@ def _update_session_state(self, app_name: str, user_id: str, session_id: str, st WHERE app_name = %s AND user_id = %s AND id = %s """ - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(sql.encode(), (Jsonb(state), app_name, user_id, session_id)) conn.commit() def _delete_session(self, app_name: str, user_id: str, session_id: str) -> None: sql = f"DELETE FROM {self._session_table} WHERE app_name = %s AND user_id = %s AND id = %s" - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(sql.encode(), (app_name, user_id, session_id)) conn.commit() @@ -857,7 +858,7 @@ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Ses params = (app_name, user_id) try: - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(sql.encode(), params) rows = cur.fetchall() @@ -915,7 +916,7 @@ def _append_event_and_update_state( event_data_value = event_record["event_data"] jsonb_value = Jsonb(event_data_value) if isinstance(event_data_value, dict) else event_data_value - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute( insert_sql.encode(), ( @@ -957,7 +958,7 @@ def _insert_event(self, event_record: EventRecord) -> None: event_data_value = event_record["event_data"] jsonb_value = Jsonb(event_data_value) if isinstance(event_data_value, dict) else event_data_value - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute( sql.encode(), ( @@ -998,7 +999,7 @@ def _get_events( params.append(limit) try: - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(sql.encode(), tuple(params)) rows = cur.fetchall() @@ -1021,7 +1022,7 @@ def _delete_expired_events(self, before: "datetime") -> int: sql = f"DELETE FROM {self._events_table} WHERE timestamp < %s" try: - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(sql.encode(), (before,)) conn.commit() return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 @@ -1032,7 +1033,7 @@ def _delete_idle_sessions(self, updated_before: "datetime") -> int: sql = f"DELETE FROM {self._session_table} WHERE update_time < %s" try: - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(sql.encode(), (updated_before,)) conn.commit() return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 @@ -1043,7 +1044,7 @@ def _get_app_state(self, app_name: str) -> "dict[str, Any] | None": sql = f"SELECT state FROM {self._app_state_table} WHERE app_name = %s" try: - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(sql.encode(), (app_name,)) row = cur.fetchone() return row["state"] if row is not None else None @@ -1054,7 +1055,7 @@ def _get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None sql = f"SELECT state FROM {self._user_state_table} WHERE app_name = %s AND user_id = %s" try: - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(sql.encode(), (app_name, user_id)) row = cur.fetchone() return row["state"] if row is not None else None @@ -1067,7 +1068,7 @@ def _upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: VALUES (%s, %s, CURRENT_TIMESTAMP) """ - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(sql.encode(), (app_name, Jsonb(state))) conn.commit() @@ -1077,7 +1078,7 @@ def _upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any] VALUES (%s, %s, %s, CURRENT_TIMESTAMP) """ - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(sql.encode(), (app_name, user_id, Jsonb(state))) conn.commit() @@ -1085,7 +1086,7 @@ def _get_metadata(self, key: str) -> "str | None": sql = f"SELECT value FROM {self._metadata_table} WHERE key = %s" try: - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(sql.encode(), (key,)) row = cur.fetchone() return row["value"] if row is not None else None @@ -1098,7 +1099,7 @@ def _set_metadata(self, key: str, value: str) -> None: VALUES (%s, %s) """ - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(sql.encode(), (key, value)) conn.commit() @@ -1155,7 +1156,7 @@ async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: " ON CONFLICT (event_id) DO NOTHING """).format(table=pg_sql.Identifier(self._memory_table)) - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: for entry in entries: if self._owner_id_column_name: await cur.execute(query, _build_insert_params_with_owner(entry, owner_id)) @@ -1199,7 +1200,7 @@ async def search_entries( params = (app_name, user_id, search_param, effective_limit) try: - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(sql.encode(), params) rows = await cur.fetchall() columns = [col[0] for col in cur.description or []] @@ -1214,7 +1215,7 @@ async def delete_entries_by_session(self, session_id: str) -> int: raise RuntimeError(msg) sql = f"DELETE FROM {self._memory_table} WHERE session_id = %s" - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(sql.encode(), (session_id,)) await conn.commit() return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 @@ -1228,7 +1229,7 @@ async def delete_entries_older_than(self, days: int) -> int: DELETE FROM {self._memory_table} WHERE inserted_at < CURRENT_TIMESTAMP - INTERVAL '{days} days' """ - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(sql.encode()) await conn.commit() return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 @@ -1380,7 +1381,7 @@ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "objec ON CONFLICT (event_id) DO NOTHING """).format(table=pg_sql.Identifier(self._memory_table)) - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: for entry in entries: if self._owner_id_column_name: cur.execute(query, _build_insert_params_with_owner(entry, owner_id)) @@ -1422,7 +1423,7 @@ def _search_entries( params = (app_name, user_id, search_param, effective_limit) try: - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(sql.encode(), params) rows = cur.fetchall() columns = [col[0] for col in cur.description or []] @@ -1437,7 +1438,7 @@ def _delete_entries_by_session(self, session_id: str) -> int: raise RuntimeError(msg) sql = f"DELETE FROM {self._memory_table} WHERE session_id = %s" - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(sql.encode(), (session_id,)) conn.commit() return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 @@ -1451,7 +1452,7 @@ def _delete_entries_older_than(self, days: int) -> int: DELETE FROM {self._memory_table} WHERE inserted_at < CURRENT_TIMESTAMP - INTERVAL '{days} days' """ - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(sql.encode()) conn.commit() return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 diff --git a/sqlspec/adapters/duckdb/adk/store.py b/sqlspec/adapters/duckdb/adk/store.py index 63c750426..acbba51f5 100644 --- a/sqlspec/adapters/duckdb/adk/store.py +++ b/sqlspec/adapters/duckdb/adk/store.py @@ -876,7 +876,7 @@ class DuckdbADKMemoryStore(BaseAsyncADKMemoryStore["DuckDBConfig"]): database="app.ddb", extension_config={ "adk": { - "memory_table": "adk_memory_entries", + "memory_table": "adk_memory", "memory_max_results": 20, } } @@ -906,7 +906,7 @@ def __init__(self, config: "DuckDBConfig") -> None: Notes: Configuration is read from config.extension_config["adk"]: - - memory_table: Memory table name (default: "adk_memory_entries") + - memory_table: Memory table name (default: "adk_memory") - memory_use_fts: Enable full-text search when supported (default: False) - memory_max_results: Max search results (default: 20) - owner_id_column: Optional owner FK column DDL (default: None) diff --git a/sqlspec/adapters/psqlpy/adk/store.py b/sqlspec/adapters/psqlpy/adk/store.py index 91658f7f2..11217a835 100644 --- a/sqlspec/adapters/psqlpy/adk/store.py +++ b/sqlspec/adapters/psqlpy/adk/store.py @@ -118,7 +118,7 @@ async def get_session( create_time=row["create_time"], update_time=row["update_time"], ) - except psqlpy.exceptions.DatabaseError as e: + except (psqlpy.exceptions.DatabaseError, psqlpy.exceptions.ConnectionExecuteError) as e: error_msg = str(e).lower() if "does not exist" in error_msg or "relation" in error_msg: return None @@ -174,7 +174,7 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis ) for row in rows ] - except psqlpy.exceptions.DatabaseError as e: + except (psqlpy.exceptions.DatabaseError, psqlpy.exceptions.ConnectionExecuteError) as e: error_msg = str(e).lower() if "does not exist" in error_msg or "relation" in error_msg: return [] @@ -313,7 +313,7 @@ async def get_events( ) for row in rows ] - except psqlpy.exceptions.DatabaseError as e: + except (psqlpy.exceptions.DatabaseError, psqlpy.exceptions.ConnectionExecuteError) as e: error_msg = str(e).lower() if "does not exist" in error_msg or "relation" in error_msg: return [] @@ -330,7 +330,7 @@ async def delete_expired_events(self, before: "datetime") -> int: count = int(count_rows[0]["count"]) if count_rows else 0 await conn.execute(delete_sql, [before]) return count - except psqlpy.exceptions.DatabaseError as e: + except (psqlpy.exceptions.DatabaseError, psqlpy.exceptions.ConnectionExecuteError) as e: error_msg = str(e).lower() if "does not exist" in error_msg or "relation" in error_msg: return 0 @@ -347,7 +347,7 @@ async def delete_idle_sessions(self, updated_before: "datetime") -> int: count = int(count_rows[0]["count"]) if count_rows else 0 await conn.execute(delete_sql, [updated_before]) return count - except psqlpy.exceptions.DatabaseError as e: + except (psqlpy.exceptions.DatabaseError, psqlpy.exceptions.ConnectionExecuteError) as e: error_msg = str(e).lower() if "does not exist" in error_msg or "relation" in error_msg: return 0 @@ -361,7 +361,7 @@ async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": result = await conn.fetch(sql, [app_name]) rows: list[dict[str, Any]] = result.result() if result else [] return rows[0]["state"] if rows else None - except psqlpy.exceptions.DatabaseError as e: + except (psqlpy.exceptions.DatabaseError, psqlpy.exceptions.ConnectionExecuteError) as e: error_msg = str(e).lower() if "does not exist" in error_msg or "relation" in error_msg: return None @@ -375,7 +375,7 @@ async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | result = await conn.fetch(sql, [app_name, user_id]) rows: list[dict[str, Any]] = result.result() if result else [] return rows[0]["state"] if rows else None - except psqlpy.exceptions.DatabaseError as e: + except (psqlpy.exceptions.DatabaseError, psqlpy.exceptions.ConnectionExecuteError) as e: error_msg = str(e).lower() if "does not exist" in error_msg or "relation" in error_msg: return None @@ -413,7 +413,7 @@ async def get_metadata(self, key: str) -> "str | None": result = await conn.fetch(sql, [key]) rows: list[dict[str, Any]] = result.result() if result else [] return rows[0]["value"] if rows else None - except psqlpy.exceptions.DatabaseError as e: + except (psqlpy.exceptions.DatabaseError, psqlpy.exceptions.ConnectionExecuteError) as e: error_msg = str(e).lower() if "does not exist" in error_msg or "relation" in error_msg: return None @@ -629,7 +629,7 @@ async def search_entries( except Exception as exc: # pragma: no cover - defensive fallback logger.warning("FTS search failed; falling back to simple search: %s", exc) return await self._search_entries_simple(query, app_name, user_id, effective_limit) - except psqlpy.exceptions.DatabaseError as e: + except (psqlpy.exceptions.DatabaseError, psqlpy.exceptions.ConnectionExecuteError) as e: error_msg = str(e).lower() if "does not exist" in error_msg or "relation" in error_msg: return [] @@ -647,7 +647,7 @@ async def delete_entries_by_session(self, session_id: str) -> int: count = int(count_rows[0]["count"]) if count_rows else 0 await conn.execute(delete_sql, [session_id]) return count - except psqlpy.exceptions.DatabaseError as e: + except (psqlpy.exceptions.DatabaseError, psqlpy.exceptions.ConnectionExecuteError) as e: error_msg = str(e).lower() if "does not exist" in error_msg or "relation" in error_msg: return 0 @@ -671,7 +671,7 @@ async def delete_entries_older_than(self, days: int) -> int: count = int(count_rows[0]["count"]) if count_rows else 0 await conn.execute(delete_sql, []) return count - except psqlpy.exceptions.DatabaseError as e: + except (psqlpy.exceptions.DatabaseError, psqlpy.exceptions.ConnectionExecuteError) as e: error_msg = str(e).lower() if "does not exist" in error_msg or "relation" in error_msg: return 0 diff --git a/sqlspec/adapters/psycopg/adk/store.py b/sqlspec/adapters/psycopg/adk/store.py index 15def115f..0bc1c481b 100644 --- a/sqlspec/adapters/psycopg/adk/store.py +++ b/sqlspec/adapters/psycopg/adk/store.py @@ -4,6 +4,7 @@ from psycopg import errors from psycopg import sql as pg_sql +from psycopg.rows import dict_row from psycopg.types.json import Jsonb from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord @@ -110,7 +111,7 @@ async def create_session( """).format(table=pg_sql.Identifier(self._session_table)) params = (session_id, app_name, user_id, Jsonb(state)) - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(query, params) return await self.get_session(app_name, user_id, session_id) # type: ignore[return-value] @@ -135,7 +136,7 @@ async def get_session( params = (app_name, user_id, session_id) try: - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(query, params) row = await cur.fetchone() @@ -160,7 +161,7 @@ async def update_session_state(self, app_name: str, user_id: str, session_id: st WHERE app_name = %s AND user_id = %s AND id = %s """).format(table=pg_sql.Identifier(self._session_table)) - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(query, (Jsonb(state), app_name, user_id, session_id)) async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: @@ -168,7 +169,7 @@ async def delete_session(self, app_name: str, user_id: str, session_id: str) -> table=pg_sql.Identifier(self._session_table) ) - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(query, (app_name, user_id, session_id)) async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": @@ -190,7 +191,7 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis params = (app_name, user_id) try: - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(query, params) rows = await cur.fetchall() @@ -218,7 +219,7 @@ async def append_event(self, event_record: EventRecord) -> None: event_data_value = event_record["event_data"] jsonb_value = Jsonb(event_data_value) if isinstance(event_data_value, dict) else event_data_value - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute( query, ( @@ -273,7 +274,7 @@ async def append_event_and_update_state( event_data_value = event_record["event_data"] jsonb_value = Jsonb(event_data_value) if isinstance(event_data_value, dict) else event_data_value - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute( insert_query, ( @@ -340,7 +341,7 @@ async def get_events( ) try: - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(query, tuple(params)) rows = await cur.fetchall() @@ -365,7 +366,7 @@ async def delete_expired_events(self, before: "datetime") -> int: ) try: - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(query, (before,)) await conn.commit() return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 @@ -378,7 +379,7 @@ async def delete_idle_sessions(self, updated_before: "datetime") -> int: ) try: - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(query, (updated_before,)) await conn.commit() return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 @@ -391,7 +392,7 @@ async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": ) try: - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(query, (app_name,)) row = await cur.fetchone() return row["state"] if row is not None else None @@ -404,7 +405,7 @@ async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | ) try: - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(query, (app_name, user_id)) row = await cur.fetchone() return row["state"] if row is not None else None @@ -420,7 +421,7 @@ async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None update_time = CURRENT_TIMESTAMP """).format(table=pg_sql.Identifier(self._app_state_table)) - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(query, (app_name, Jsonb(state))) async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: @@ -432,7 +433,7 @@ async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, update_time = CURRENT_TIMESTAMP """).format(table=pg_sql.Identifier(self._user_state_table)) - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(query, (app_name, user_id, Jsonb(state))) async def get_metadata(self, key: str) -> "str | None": @@ -441,7 +442,7 @@ async def get_metadata(self, key: str) -> "str | None": ) try: - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(query, (key,)) row = await cur.fetchone() return row["value"] if row is not None else None @@ -455,7 +456,7 @@ async def set_metadata(self, key: str, value: str) -> None: ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value """).format(table=pg_sql.Identifier(self._metadata_table)) - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(query, (key, value)) async def _get_create_sessions_table_sql(self) -> str: @@ -792,7 +793,7 @@ def _create_session( """).format(table=pg_sql.Identifier(self._session_table)) params = (session_id, app_name, user_id, Jsonb(state)) - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(query, params) result = self._get_session(app_name, user_id, session_id) @@ -821,7 +822,7 @@ def _get_session( params = (app_name, user_id, session_id) try: - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(query, params) row = cur.fetchone() @@ -846,7 +847,7 @@ def _update_session_state(self, app_name: str, user_id: str, session_id: str, st WHERE app_name = %s AND user_id = %s AND id = %s """).format(table=pg_sql.Identifier(self._session_table)) - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(query, (Jsonb(state), app_name, user_id, session_id)) def _delete_session(self, app_name: str, user_id: str, session_id: str) -> None: @@ -854,7 +855,7 @@ def _delete_session(self, app_name: str, user_id: str, session_id: str) -> None: table=pg_sql.Identifier(self._session_table) ) - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(query, (app_name, user_id, session_id)) def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": @@ -876,7 +877,7 @@ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Ses params = (app_name, user_id) try: - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(query, params) rows = cur.fetchall() @@ -904,7 +905,7 @@ def _insert_event(self, event_record: EventRecord) -> None: event_data_value = event_record["event_data"] jsonb_value = Jsonb(event_data_value) if isinstance(event_data_value, dict) else event_data_value - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute( insert_query, ( @@ -960,7 +961,7 @@ def _append_event_and_update_state( event_data_value = event_record["event_data"] jsonb_value = Jsonb(event_data_value) if isinstance(event_data_value, dict) else event_data_value - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute( insert_query, ( @@ -1027,7 +1028,7 @@ def _get_events( ) try: - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(query, tuple(params)) rows = cur.fetchall() @@ -1052,7 +1053,7 @@ def _delete_expired_events(self, before: "datetime") -> int: ) try: - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(query, (before,)) conn.commit() return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 @@ -1065,7 +1066,7 @@ def _delete_idle_sessions(self, updated_before: "datetime") -> int: ) try: - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(query, (updated_before,)) conn.commit() return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 @@ -1078,7 +1079,7 @@ def _get_app_state(self, app_name: str) -> "dict[str, Any] | None": ) try: - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(query, (app_name,)) row = cur.fetchone() return row["state"] if row is not None else None @@ -1091,7 +1092,7 @@ def _get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None ) try: - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(query, (app_name, user_id)) row = cur.fetchone() return row["state"] if row is not None else None @@ -1107,7 +1108,7 @@ def _upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: update_time = CURRENT_TIMESTAMP """).format(table=pg_sql.Identifier(self._app_state_table)) - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(query, (app_name, Jsonb(state))) conn.commit() @@ -1120,7 +1121,7 @@ def _upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any] update_time = CURRENT_TIMESTAMP """).format(table=pg_sql.Identifier(self._user_state_table)) - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(query, (app_name, user_id, Jsonb(state))) conn.commit() @@ -1130,7 +1131,7 @@ def _get_metadata(self, key: str) -> "str | None": ) try: - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(query, (key,)) row = cur.fetchone() return row["value"] if row is not None else None @@ -1144,7 +1145,7 @@ def _set_metadata(self, key: str, value: str) -> None: ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value """).format(table=pg_sql.Identifier(self._metadata_table)) - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(query, (key, value)) conn.commit() @@ -1204,7 +1205,7 @@ async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: " ON CONFLICT (event_id) DO NOTHING """).format(table=pg_sql.Identifier(self._memory_table)) - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: for entry in entries: if self._owner_id_column_name: await cur.execute(query, _build_insert_params_with_owner(entry, owner_id)) @@ -1241,7 +1242,7 @@ async def delete_entries_by_session(self, session_id: str) -> int: table=pg_sql.Identifier(self._memory_table) ) - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(sql, (session_id,)) return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 @@ -1254,7 +1255,7 @@ async def delete_entries_older_than(self, days: int) -> int: """ ).format(table=pg_sql.Identifier(self._memory_table), interval=pg_sql.Literal(f"{days} days")) - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(sql) return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 @@ -1313,7 +1314,7 @@ async def _search_entries_fts(self, query: str, app_name: str, user_id: str, lim """ ).format(table=pg_sql.Identifier(self._memory_table)) params: tuple[str, str, str, str, int] = (query, app_name, user_id, query, limit) - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(sql, params) rows = await cur.fetchall() return _rows_to_records(rows) @@ -1333,7 +1334,7 @@ async def _search_entries_simple(self, query: str, app_name: str, user_id: str, ).format(table=pg_sql.Identifier(self._memory_table)) pattern = f"%{query}%" params: tuple[str, str, str, int] = (app_name, user_id, pattern, limit) - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(sql, params) rows = await cur.fetchall() return _rows_to_records(rows) @@ -1452,7 +1453,7 @@ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "objec ON CONFLICT (event_id) DO NOTHING """).format(table=pg_sql.Identifier(self._memory_table)) - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: for entry in entries: if self._owner_id_column_name: cur.execute(query, _build_insert_params_with_owner(entry, owner_id)) @@ -1498,7 +1499,7 @@ def _search_entries_fts(self, query: str, app_name: str, user_id: str, limit: in """ ).format(table=pg_sql.Identifier(self._memory_table)) params: tuple[str, str, str, str, int] = (query, app_name, user_id, query, limit) - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(sql, params) rows = cur.fetchall() return _rows_to_records(rows) @@ -1518,7 +1519,7 @@ def _search_entries_simple(self, query: str, app_name: str, user_id: str, limit: ).format(table=pg_sql.Identifier(self._memory_table)) pattern = f"%{query}%" params: tuple[str, str, str, int] = (app_name, user_id, pattern, limit) - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(sql, params) rows = cur.fetchall() return _rows_to_records(rows) @@ -1529,7 +1530,7 @@ def _delete_entries_by_session(self, session_id: str) -> int: table=pg_sql.Identifier(self._memory_table) ) - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(sql, (session_id,)) return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 @@ -1542,7 +1543,7 @@ def _delete_entries_older_than(self, days: int) -> int: """ ).format(table=pg_sql.Identifier(self._memory_table), interval=pg_sql.Literal(f"{days} days")) - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(sql) return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 diff --git a/sqlspec/adapters/sqlite/adk/store.py b/sqlspec/adapters/sqlite/adk/store.py index 53579d7b7..3a8e0a12a 100644 --- a/sqlspec/adapters/sqlite/adk/store.py +++ b/sqlspec/adapters/sqlite/adk/store.py @@ -868,7 +868,7 @@ class SqliteADKMemoryStore(BaseAsyncADKMemoryStore["SqliteConfig"]): database="app.db", extension_config={ "adk": { - "memory_table": "adk_memory_entries", + "memory_table": "adk_memory", "memory_use_fts": False, "memory_max_results": 20, } @@ -896,7 +896,7 @@ def __init__(self, config: "SqliteConfig") -> None: Notes: Configuration is read from config.extension_config["adk"]: - - memory_table: Memory table name (default: "adk_memory_entries") + - memory_table: Memory table name (default: "adk_memory") - memory_use_fts: Enable full-text search when supported (default: False) - memory_max_results: Max search results (default: 20) - owner_id_column: Optional owner FK column DDL (default: None) diff --git a/sqlspec/config.py b/sqlspec/config.py index 5a0eb0025..c90d66c5d 100644 --- a/sqlspec/config.py +++ b/sqlspec/config.py @@ -492,688 +492,107 @@ class SanicConfig(TypedDict): """Framework name for SQLCommenter attributes. Default: 'sanic'.""" -class ADKPartitionConfig(TypedDict): - """Configuration for table partitioning and sharding strategies. - - Controls how ADK tables are partitioned across backends that support it. - Backends without native partitioning support ignore these settings. - - Example: - extension_config={ - "adk": { - "partitioning": { - "strategy": "range", - "partition_key": "created_at", - "interval": "month", - } - } - } - """ - - strategy: NotRequired[Literal["range", "list", "hash"]] - """Partitioning strategy. Default: None (no partitioning). - - - range: Partition by range of values (e.g., time-based) - - list: Partition by discrete value lists - - hash: Partition by hash of the partition key - - Supported by: PostgreSQL, MySQL 8+, Oracle, Spanner. - Ignored by: SQLite, DuckDB. - """ - - partition_key: NotRequired[str] - """Column name used as the partition key. - - For range partitioning with time-based data, this is typically a timestamp column - like 'created_at'. For hash partitioning, this is typically the primary key. - """ - - session_partition_key: NotRequired[str] - """Session-table partition key override for adapters that create separate ADK tables.""" - - events_partition_key: NotRequired[str] - """Event-table partition key override for adapters that create separate ADK tables.""" - - memory_partition_key: NotRequired[str] - """Memory-table partition key override for adapters that create separate ADK tables.""" - - interval: NotRequired[str] - """Partition interval for range partitioning. - - Examples: 'day', 'week', 'month', 'year'. - Only meaningful when strategy is 'range'. - """ - - partition_count: NotRequired[int] - """Number of hash partitions for adapters that support hash-partitioned ADK tables.""" - - initial_less_than: NotRequired[str] - """Initial range-partition upper bound for adapters that require a seed partition.""" - - -class ADKRetentionConfig(TypedDict): - """Configuration for data retention and TTL policies. - - Controls automatic cleanup of expired data. Backends with native TTL support - (CockroachDB Row-Level TTL, Spanner Row Deletion Policy) use database-level - enforcement. Others fall back to application-level sweep queries. - - Example: - extension_config={ - "adk": { - "retention": { - "session_ttl_seconds": 86400, - "event_ttl_seconds": 604800, - "memory_ttl_seconds": 0, - } - } - } - """ - - session_ttl_seconds: NotRequired[int] - """TTL for session records in seconds. Default: 0 (no expiry). - - When set, sessions older than this threshold are eligible for cleanup. - Backends with native TTL (CockroachDB, Spanner) enforce this at the database level. - Others require application-level cleanup via periodic sweep. - """ - - event_ttl_seconds: NotRequired[int] - """TTL for event records in seconds. Default: 0 (no expiry). - - When set, events older than this threshold are eligible for cleanup. - """ - - memory_ttl_seconds: NotRequired[int] - """TTL for memory entries in seconds. Default: 0 (no expiry). - - When set, memory entries older than this threshold are eligible for cleanup. - """ - - sweep_interval_seconds: NotRequired[int] - """Interval between application-level cleanup sweeps in seconds. Default: 3600 (1 hour). - - Only used when the backend does not support native TTL enforcement. - Set to 0 to disable automatic sweeps (manual cleanup only). - """ - - -class ADKCompressionConfig(TypedDict): - """Configuration for table-level compression. - - Controls compression of ADK table storage. Support and algorithms vary by backend. - - Example: - extension_config={ - "adk": { - "compression": { - "enabled": True, - "algorithm": "zstd", - } - } - } - """ - - enabled: NotRequired[bool] - """Enable table compression. Default: False. - - When True, adapters that support table-level compression will apply it - during table creation. - """ - - algorithm: NotRequired[str] - """Compression algorithm name. Backend-specific. - - Examples: - - PostgreSQL (with TOAST): 'pglz', 'lz4' (PG14+) - - MySQL/InnoDB: 'zlib' - - Oracle: 'basic', 'oltp', 'query_high', 'archive_high' - - DuckDB: 'zstd', 'snappy' - - When omitted, the backend default is used. - """ - - level: NotRequired[int] - """Compression level (where supported). Higher levels trade CPU for space savings. - - Valid ranges depend on the algorithm and backend. - """ - - -class ADKSqliteOptimizationConfig(TypedDict): - """SQLite-specific PRAGMA optimization settings. - - Controls SQLite performance tuning parameters applied at connection time. - These settings are ignored by non-SQLite adapters. - - Example: - extension_config={ - "adk": { - "sqlite_optimization": { - "cache_size": -64000, - "mmap_size": 31457280, - "journal_size_limit": 67108864, - } - } - } - """ - - cache_size: NotRequired[int] - """SQLite page cache size. Default: -64000 (64 MB, negative means KiB). - - Larger caches reduce disk I/O for read-heavy workloads. - Negative values specify size in KiB; positive values specify page count. - """ - - mmap_size: NotRequired[int] - """SQLite memory-mapped I/O size in bytes. Default: 31457280 (30 MB). - - Enables memory-mapped I/O for faster reads. Set to 0 to disable. - """ - - journal_size_limit: NotRequired[int] - """SQLite journal file size limit in bytes. Default: 67108864 (64 MB). - - Limits the size of the WAL or rollback journal file. - Prevents unbounded journal growth in write-heavy workloads. - """ - - -ADKOptimizationMode: TypeAlias = Literal["auto", "enable", "disable"] -"""Tri-state optimization control used by ADK capability negotiation.""" - - -class ADKSchemaConfig(TypedDict): - """Shared ADK schema naming and migration controls.""" - - session_table: NotRequired[str] - events_table: NotRequired[str] - memory_table: NotRequired[str] - artifact_table: NotRequired[str] - app_state_table: NotRequired[str] - user_state_table: NotRequired[str] - metadata_table: NotRequired[str] - owner_id_column: NotRequired[str] - schema_version: NotRequired[int] - payload_versions: NotRequired["ADKPayloadVersionsConfig"] - include_sessions_migration: NotRequired[bool] - include_memory_migration: NotRequired[bool] - include_artifact_migration: NotRequired[bool] - - -class ADKSearchConfig(TypedDict): - """Shared ADK search configuration.""" - - strategy: NotRequired[Literal["auto", "like", "fts", "vector"]] - use_fts: NotRequired[bool] - language: NotRequired[str] - max_results: NotRequired[int] - - -class ADKPayloadVersionsConfig(TypedDict): - """Shared ADK payload version pinning.""" - - event: NotRequired[int] - state: NotRequired[int] - memory: NotRequired[int] - artifact: NotRequired[int] - - -class ADKIndexingConfig(TypedDict): - """Shared ADK index lifecycle controls.""" - - generated_columns: NotRequired[ADKOptimizationMode] - covering_indexes: NotRequired[ADKOptimizationMode] - search_indexes: NotRequired[ADKOptimizationMode] - json_indexes: NotRequired[ADKOptimizationMode] - vector_indexes: NotRequired[ADKOptimizationMode] - - -class ADKTableOptionsConfig(TypedDict): - """Shared ADK table and index option attachment points.""" - - sessions: NotRequired[str] - events: NotRequired[str] - memory: NotRequired[str] - artifacts: NotRequired[str] - app_states: NotRequired[str] - user_states: NotRequired[str] - metadata: NotRequired[str] - expires_index: NotRequired[str] - - -class ADKLifecycleConfig(TypedDict): - """Shared ADK lifecycle controls for backend DDL chapters.""" - - partitioning: NotRequired[ADKPartitionConfig] - retention: NotRequired[ADKRetentionConfig] - indexing: NotRequired[ADKIndexingConfig] - compression: NotRequired[ADKCompressionConfig] - table_options: NotRequired[ADKTableOptionsConfig] - - -class ADKCapabilityConfig(TypedDict): - """Shared ADK capability detection overrides.""" - - overrides: NotRequired[dict[str, ADKOptimizationMode]] - - -class ADKMemoryConfig(TypedDict): - """Shared ADK memory configuration.""" - - enabled: NotRequired[bool] - table: NotRequired[str] - max_results: NotRequired[int] - search: NotRequired[ADKSearchConfig] - - -class ADKArtifactConfig(TypedDict): - """Shared ADK artifact configuration.""" - - table: NotRequired[str] - storage_uri: NotRequired[str] - - -class ADKOptimizationConfig(TypedDict): - """Shared ADK data-model optimization controls.""" - - generated_columns: NotRequired[ADKOptimizationMode] - null_encoded_empty_state: NotRequired[ADKOptimizationMode] - skip_noop_session_update: NotRequired[ADKOptimizationMode] - append_only_event_partitioning: NotRequired[ADKOptimizationMode] - covering_indexes: NotRequired[ADKOptimizationMode] - duckdb_struct_events: NotRequired[ADKOptimizationMode] - spanner_commit_timestamp_pk_suffix: NotRequired[ADKOptimizationMode] - alloydb_columnar_autopromote: NotRequired[ADKOptimizationMode] - - -class ADKOracleConfig(TypedDict): - """Oracle-specific ADK capability settings.""" - - in_memory: NotRequired[bool] - session_table_options: NotRequired[str] - events_table_options: NotRequired[str] - memory_table_options: NotRequired[str] - compression: NotRequired[ADKCompressionConfig] - partitioning: NotRequired[ADKPartitionConfig] - - -class ADKSpannerConfig(TypedDict): - """Spanner-specific ADK capability settings.""" - - shard_count: NotRequired[int] - interleave_events_in_sessions: NotRequired[bool] - session_table_options: NotRequired[str] - events_table_options: NotRequired[str] - memory_table_options: NotRequired[str] - - -class ADKADBCConfig(TypedDict): - """ADBC-specific ADK capability settings.""" - - dialect: NotRequired[str] - table_options: NotRequired[str] - - -class ADKBigQueryConfig(TypedDict): - """BigQuery-specific ADK capability settings.""" - - dataset: NotRequired[str] - partition_expiration_days: NotRequired[int] - clustering_fields: NotRequired[tuple[str, ...] | list[str]] - table_options: NotRequired[str] - - class ADKConfig(TypedDict): - """Configuration options for ADK session and memory store extension. + """Configuration options for ADK session, memory, and artifact storage. - All fields are optional with sensible defaults. Use in extension_config["adk"]: - - Configuration supports three deployment scenarios: - 1. SQLSpec manages everything (runtime + migrations) - 2. SQLSpec runtime only (external migration tools like Alembic/Flyway) - 3. Selective features (sessions OR memory, not both) - - Example: - from sqlspec.adapters.asyncpg import AsyncpgConfig - - config = AsyncpgConfig( - connection_config={"dsn": "postgresql://localhost/mydb"}, - extension_config={ - "adk": { - "session_table": "my_sessions", - "events_table": "my_events", - "memory_table": "my_memories", - "memory_use_fts": True, - "owner_id_column": "tenant_id INTEGER REFERENCES tenants(id)" - } - } - ) - - Notes: - This TypedDict provides type safety for extension config but is not required. - You can use plain dicts as well. + Use in ``extension_config["adk"]``. All fields are optional. Adapters read only + the keys they understand and ignore the rest; adapter-specific knobs (``in_memory``, + ``shard_count``, ``*_table_options``, ``partitioning``, ``compression``, ``retention``) + follow the same pattern as ``EventsConfig.in_memory``. """ - schema: NotRequired[ADKSchemaConfig] - """Shared schema naming, owner binding, and migration controls.""" - - memory: NotRequired[ADKMemoryConfig] - """Shared memory enablement and result-limit controls.""" - - search: NotRequired[ADKSearchConfig] - """Shared search strategy and language controls.""" - - payloads: NotRequired[ADKPayloadVersionsConfig] - """Shared payload version pins.""" - - lifecycle: NotRequired[ADKLifecycleConfig] - """Shared lifecycle controls for partitioning, retention, indexing, compression, and table options.""" - - capabilities: NotRequired[ADKCapabilityConfig] - """Shared detected capability overrides.""" - - artifact: NotRequired[ADKArtifactConfig] - """Shared artifact metadata and storage URI controls.""" - - optimizations: NotRequired[ADKOptimizationConfig] - """Shared optimization negotiation controls.""" - - oracle: NotRequired[ADKOracleConfig] - """Oracle-specific ADK capability settings.""" - - spanner: NotRequired[ADKSpannerConfig] - """Spanner-specific ADK capability settings.""" - - adbc: NotRequired[ADKADBCConfig] - """ADBC-specific ADK capability settings.""" - - bigquery: NotRequired[ADKBigQueryConfig] - """BigQuery-specific ADK capability settings.""" - enable_sessions: NotRequired[bool] - """Enable session store at runtime. Default: True. - - When False: session service unavailable, session store operations disabled. - Independent of migration control - can use externally-managed tables. - """ + """Enable session store at runtime. Defaults to True.""" enable_memory: NotRequired[bool] - """Enable memory store at runtime. Default: True. - - When False: memory service unavailable, memory store operations disabled. - Independent of migration control - can use externally-managed tables. - """ + """Enable memory store at runtime. Defaults to True.""" include_sessions_migration: NotRequired[bool] - """Include session tables in SQLSpec migrations. Default: True. - - When False: session migration DDL skipped (use external migration tools). - Decoupled from enable_sessions - allows external table management with SQLSpec runtime. - """ + """Include session tables in SQLSpec migrations. Defaults to True.""" include_memory_migration: NotRequired[bool] - """Include memory tables in SQLSpec migrations. Default: True. - - When False: memory migration DDL skipped (use external migration tools). - Decoupled from enable_memory - allows external table management with SQLSpec runtime. - """ + """Include memory tables in SQLSpec migrations. Defaults to True.""" session_table: NotRequired[str] - """Name of the sessions table. Default: 'adk_session' - - Examples: - "agent_sessions" - "my_app_sessions" - "tenant_acme_sessions" - """ + """Sessions table name. Defaults to ``"adk_session"``.""" events_table: NotRequired[str] - """Name of the events table. Default: 'adk_event' - - Examples: - "agent_events" - "my_app_events" - "tenant_acme_events" - """ + """Events table name. Defaults to ``"adk_event"``.""" app_state_table: NotRequired[str] - """Name of the app-scoped state table. Default: 'adk_app_state'.""" + """App-scoped state table name. Defaults to ``"adk_app_state"``.""" user_state_table: NotRequired[str] - """Name of the user-scoped state table. Default: 'adk_user_state'.""" + """User-scoped state table name. Defaults to ``"adk_user_state"``.""" metadata_table: NotRequired[str] - """Name of the internal ADK metadata table. Default: 'adk_internal_metadata'.""" + """Internal ADK metadata table name. Defaults to ``"adk_metadata"``.""" memory_table: NotRequired[str] - """Name of the memory entries table. Default: 'adk_memory_entries' - - Examples: - "agent_memories" - "my_app_memories" - "tenant_acme_memories" - """ + """Memory entries table name. Defaults to ``"adk_memory"``.""" artifact_table: NotRequired[str] - """Name of the artifact versions table. Default: 'adk_artifact_versions' - - Examples: - "agent_artifacts" - "my_app_artifact_versions" - """ + """Artifact versions table name. Defaults to ``"adk_artifact"``.""" artifact_storage_uri: NotRequired[str] - """Base URI for artifact content storage. + """Storage backend reference for artifact binary content. - Points to a ``sqlspec/storage/`` backend where artifact binary content - is stored. Can be a direct URI (``s3://bucket/path``, ``file:///path``) - or a registered alias in the storage registry. + Forwarded to :meth:`sqlspec.storage.StorageRegistry.get`. Accepts either: - Examples: - "s3://my-bucket/adk-artifacts/" - "file:///var/data/artifacts/" - "gcs://my-gcs-bucket/artifacts/" + - A direct URI: ``"s3://my-bucket/adk/"``, ``"file:///var/data/adk/"``, ``"gs://bucket/adk/"``. + - A registered alias: ``"artifacts"`` — register via + ``storage_registry.register_alias("artifacts", "s3://my-bucket/adk/", ...)`` to keep + backend-specific kwargs (credentials, base_path, fsspec/obstore overrides) out of the + ADK config. """ memory_use_fts: NotRequired[bool] - """Enable full-text search when supported. Default: False. - - When True, adapters will use their native FTS capabilities where available: - - PostgreSQL: to_tsvector/to_tsquery with GIN index - - SQLite: FTS5 virtual table - - DuckDB: FTS extension with match_bm25 - - Oracle: CONTAINS() with CTXSYS.CONTEXT index - - Spanner: TOKENIZE_FULLTEXT with search index - - MySQL: MATCH...AGAINST with FULLTEXT index - - When False, adapters use simple LIKE/ILIKE queries (works without indexes). - """ + """Use the backend's native full-text search when supported. Defaults to False.""" memory_max_results: NotRequired[int] - """Maximum number of results for memory search queries. Default: 20. - - Limits the number of memory entries returned by search_memory(). - Can be overridden per-query via the limit parameter. - """ + """Default ``limit`` for ``search_memory``. Defaults to 20.""" owner_id_column: NotRequired[str] - """Optional owner ID column definition to link sessions/memories to a user, tenant, team, or other entity. - - Format: "column_name TYPE [NOT NULL] REFERENCES table(column) [options...]" + """Optional owner-ID column DDL for multi-tenancy. - The entire definition is passed through to DDL verbatim. We only parse - the column name (first word) for use in INSERT/SELECT statements. + Appended to the ``adk_session`` and ``adk_memory`` tables only. Other ADK tables + skip it by design: ``adk_event`` cascades through its sessions FK, ``adk_app_state`` + and ``adk_user_state`` are already scoped through their primary keys, and + ``adk_metadata`` is internal kv. BigQuery is the lone exception — its analytics- + replica DDL ignores ``owner_id_column`` entirely; use the BigQuery + ``clustering_fields`` / ``partitioning`` knobs for tenant-scoped layout instead. - This column is added to both session and memory tables for consistent - multi-tenant isolation. - - Supports: - - Foreign key constraints: REFERENCES table(column) - - Nullable or NOT NULL - - CASCADE options: ON DELETE CASCADE, ON UPDATE CASCADE - - Dialect-specific options (DEFERRABLE, ENABLE VALIDATE, etc.) - - Plain columns without FK (just extra column storage) - - Examples: - PostgreSQL with UUID FK: - "account_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE" - - MySQL with BIGINT FK: - "user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE RESTRICT" - - Oracle with NUMBER FK: - "user_id NUMBER(10) REFERENCES users(id) ENABLE VALIDATE" - - SQLite with INTEGER FK: - "tenant_id INTEGER NOT NULL REFERENCES tenants(id)" - - Nullable FK (optional relationship): - "workspace_id UUID REFERENCES workspaces(id) ON DELETE SET NULL" - - No FK (just extra column): - "organization_name VARCHAR(128) NOT NULL" - - Deferred constraint (PostgreSQL): - "user_id UUID REFERENCES users(id) DEFERRABLE INITIALLY DEFERRED" - - Notes: - - Column name (first word) is extracted for INSERT/SELECT queries - - Rest of definition is passed through to CREATE TABLE DDL - - Database validates the DDL syntax (fail-fast on errors) - - Works with all database dialects (PostgreSQL, MySQL, SQLite, Oracle, etc.) + Format: ``"column_name TYPE [NOT NULL] [REFERENCES table(col) ...]"``. The full + fragment is forwarded to ``CREATE TABLE`` verbatim; only the first token (column + name) is parsed for use in INSERT/SELECT statements. """ in_memory: NotRequired[bool] - """Enable in-memory table storage (Oracle-specific). Default: False. - - When enabled, tables are created with the INMEMORY clause for Oracle Database, - which stores table data in columnar format in memory for faster query performance. - - This is an Oracle-specific feature that requires: - - Oracle Database 12.1.0.2 or higher - - Database In-Memory option license (Enterprise Edition) - - Sufficient INMEMORY_SIZE configured in the database instance - - Other database adapters ignore this setting. - - Examples: - Oracle with in-memory enabled: - config = OracleAsyncConfig( - connection_config={"dsn": "oracle://..."}, - extension_config={ - "adk": { - "in_memory": True - } - } - ) - - Notes: - - Improves query performance for analytics (10-100x faster) - - Tables created with INMEMORY clause - - Requires Oracle Database In-Memory option license - - Ignored by non-Oracle adapters - """ + """Enable Oracle's ``INMEMORY`` clause for ADK tables. Ignored by other adapters.""" shard_count: NotRequired[int] - """Optional hash shard count for session/event tables to reduce hotspotting. - - When set (>1), adapters that support computed shard columns will create a - generated shard_id using MOD(FARM_FINGERPRINT(primary_key), shard_count) and - include it in the primary key and filters. Ignored by adapters that do not - support computed shards. - """ + """Hash shard count for adapters that support computed shard columns (Spanner).""" session_table_options: NotRequired[str] - """Adapter-specific table OPTIONS/clauses for the sessions table. - - Passed verbatim when supported (e.g., Spanner columnar/tiered storage). Ignored by - adapters without table OPTIONS support. - """ + """Verbatim ``OPTIONS`` clause for the sessions table (Oracle, Spanner). Ignored elsewhere.""" events_table_options: NotRequired[str] - """Adapter-specific table OPTIONS/clauses for the events table.""" + """Verbatim ``OPTIONS`` clause for the events table (Oracle, Spanner). Ignored elsewhere.""" memory_table_options: NotRequired[str] - """Adapter-specific table OPTIONS/clauses for the memory table.""" + """Verbatim ``OPTIONS`` clause for the memory table (Oracle, Spanner). Ignored elsewhere.""" expires_index_options: NotRequired[str] - """Adapter-specific options for the expires/index used in ADK stores.""" + """Verbatim options for the expires index (Spanner). Ignored elsewhere.""" - # --- Capability-based configuration (Chapter 2: schema-capability-config) --- + partitioning: NotRequired["dict[str, Any]"] + """Table partitioning configuration. Consumed by Oracle (hash/range only).""" - fts_language: NotRequired[str] - """Language configuration for full-text search indexing. Default: 'english'. + compression: NotRequired["dict[str, Any]"] + """Table compression configuration. Consumed by Oracle.""" - Controls the language dictionary/stemmer used by FTS implementations: - - PostgreSQL: to_tsvector/to_tsquery language parameter - - SQLite FTS5: tokenizer language for unicode61/porter - - MySQL: FULLTEXT parser language (with ngram for CJK on 5.7.6+) - - Oracle: CTXSYS.CONTEXT lexer language - - Spanner: TOKENIZE_FULLTEXT language parameter - - DuckDB: FTS stemmer language - - Only takes effect when ``memory_use_fts`` is True. - - Common values: 'english', 'simple', 'german', 'french', 'spanish', - 'portuguese', 'italian', 'dutch', 'russian', 'chinese', 'japanese', 'korean'. - - Notes: - Available languages vary by backend. Backends that do not support the - specified language will fall back to 'simple' or 'english'. - """ - - schema_version: NotRequired[int] - """Explicit schema version for ADK tables. Default: None (auto-detect). - - When set, locks the ADK schema to a specific version. This is useful for: - - Preventing automatic schema upgrades in production - - Pinning to a known-good schema during testing - - Coordinating schema changes across multiple application instances - - When None, the ADK extension auto-detects the current schema version - and applies any pending upgrades during initialization. - - Notes: - Schema versions are monotonically increasing integers managed by - the ADK extension migration system. Setting this to a version - lower than the current database schema will raise a configuration - error at startup. - """ - - partitioning: NotRequired[ADKPartitionConfig] - """Table partitioning configuration. Default: None (no partitioning). - - Controls how ADK tables are partitioned for improved query performance - and data management at scale. See ``ADKPartitionConfig`` for options. - - Supported by: PostgreSQL, MySQL 8+, Oracle, Spanner. - Ignored by: SQLite, DuckDB. - """ - - retention: NotRequired[ADKRetentionConfig] - """Data retention and TTL configuration. Default: None (no automatic cleanup). - - Controls automatic expiry and cleanup of old session, event, and memory data. - See ``ADKRetentionConfig`` for options. - - Backends with native TTL (CockroachDB, Spanner) use database-level enforcement. - Others fall back to application-level sweep queries. - """ - - compression: NotRequired[ADKCompressionConfig] - """Table compression configuration. Default: None (no compression). - - Controls table-level compression for ADK tables. - See ``ADKCompressionConfig`` for options. - """ - - sqlite_optimization: NotRequired[ADKSqliteOptimizationConfig] - """SQLite-specific PRAGMA optimization settings. Default: None (SQLite defaults). - - Controls SQLite performance tuning parameters. Ignored by non-SQLite adapters. - See ``ADKSqliteOptimizationConfig`` for options. - """ + retention: NotRequired["dict[str, Any]"] + """Data retention / TTL configuration. Consumed by Spanner and BigQuery.""" class EventsConfig(TypedDict): diff --git a/sqlspec/extensions/adk/_capabilities.py b/sqlspec/extensions/adk/_capabilities.py deleted file mode 100644 index 3cf181703..000000000 --- a/sqlspec/extensions/adk/_capabilities.py +++ /dev/null @@ -1,101 +0,0 @@ -"""ADK capability detection and override resolution.""" - -from collections.abc import Mapping -from dataclasses import dataclass, field -from typing import Final, Literal, TypeAlias, cast - -from sqlspec.exceptions import ImproperConfigurationError - -__all__ = ( - "ADKCapabilityDecision", - "ADKCapabilityPlan", - "normalize_adk_capability_overrides", - "resolve_adk_capability_plan", -) - -ADKCapabilityMode: TypeAlias = Literal["auto", "enable", "disable"] -ADKCapabilitySource: TypeAlias = Literal["default", "detected", "override"] - -ADK_CAPABILITY_MODES: Final[frozenset[str]] = frozenset({"auto", "enable", "disable"}) - - -@dataclass(frozen=True, slots=True) -class ADKCapabilityDecision: - """Resolved decision for one ADK capability.""" - - feature: str - detected: bool | None - override: ADKCapabilityMode - enabled: bool - source: ADKCapabilitySource - reason: str | None = None - - -@dataclass(frozen=True, slots=True) -class ADKCapabilityPlan: - """Resolved ADK capability decisions keyed by feature name.""" - - decisions: dict[str, ADKCapabilityDecision] = field(default_factory=dict) - - def enabled_features(self) -> frozenset[str]: - """Return enabled feature names.""" - - return frozenset(feature for feature, decision in self.decisions.items() if decision.enabled) - - -def normalize_adk_capability_overrides(overrides: Mapping[str, object] | None = None) -> dict[str, ADKCapabilityMode]: - """Normalize capability overrides from ADK config.""" - - normalized: dict[str, ADKCapabilityMode] = {} - for feature, value in (overrides or {}).items(): - if value not in ADK_CAPABILITY_MODES: - msg = f"Unsupported ADK capability override {value!r} for {feature}; expected auto, enable, or disable" - raise ImproperConfigurationError(msg) - normalized[str(feature)] = cast("ADKCapabilityMode", value) - return normalized - - -def resolve_adk_capability_plan( - detected_features: Mapping[str, bool | None], overrides: Mapping[str, object] | None = None -) -> ADKCapabilityPlan: - """Resolve detected ADK capabilities with user overrides.""" - - normalized_overrides = normalize_adk_capability_overrides(overrides) - decisions: dict[str, ADKCapabilityDecision] = {} - for feature in sorted(set(detected_features) | set(normalized_overrides)): - detected = detected_features.get(feature) - override = normalized_overrides.get(feature, "auto") - decisions[feature] = _resolve_capability_decision(feature, detected, override) - return ADKCapabilityPlan(decisions=decisions) - - -def _resolve_capability_decision( - feature: str, detected: bool | None, override: ADKCapabilityMode -) -> ADKCapabilityDecision: - if override == "disable": - return ADKCapabilityDecision( - feature=feature, detected=detected, override=override, enabled=False, source="override" - ) - if override == "enable": - if detected is False: - msg = f"ADK capability {feature!r} was forced enabled but detection reported it as unsupported" - raise ImproperConfigurationError(msg) - return ADKCapabilityDecision( - feature=feature, detected=detected, override=override, enabled=True, source="override" - ) - if detected is True: - return ADKCapabilityDecision( - feature=feature, detected=detected, override=override, enabled=True, source="detected" - ) - if detected is False: - return ADKCapabilityDecision( - feature=feature, detected=detected, override=override, enabled=False, source="detected" - ) - return ADKCapabilityDecision( - feature=feature, - detected=detected, - override=override, - enabled=False, - source="default", - reason="capability was not detected", - ) diff --git a/sqlspec/extensions/adk/_config_utils.py b/sqlspec/extensions/adk/_config_utils.py index d723479dc..2d0716a35 100644 --- a/sqlspec/extensions/adk/_config_utils.py +++ b/sqlspec/extensions/adk/_config_utils.py @@ -6,9 +6,6 @@ from typing_extensions import NotRequired, TypedDict from sqlspec.exceptions import SQLSpecError -from sqlspec.extensions.adk._capabilities import ADKCapabilityMode, normalize_adk_capability_overrides -from sqlspec.extensions.adk._lifecycle import ADKLifecyclePlan, resolve_adk_lifecycle_plan -from sqlspec.extensions.adk._versioning import ADKVersionPlan, resolve_adk_version_plan from sqlspec.utils.module_loader import import_string __all__ = ( @@ -17,13 +14,10 @@ "_ADKSessionStoreConfig", "_get_adk_adapter_store_class", "_get_adk_artifact_store_config", - "_get_adk_capability_overrides", "_get_adk_config_from_extension", - "_get_adk_lifecycle_plan", "_get_adk_memory_migration_store_class", "_get_adk_memory_store_config", "_get_adk_session_store_config", - "_get_adk_version_plan", "_is_adk_memory_migration_enabled", "_validate_adk_store_registration", ) @@ -72,40 +66,18 @@ def _get_adk_config_from_extension(config: _ADKConfigSource) -> dict[str, Any]: return dict(cast("dict[str, Any]", config.extension_config.get("adk", {}))) -def _get_adk_config_section(adk_config: dict[str, Any], name: str) -> dict[str, Any]: - """Return a mutable nested ADK config section.""" - - value = adk_config.get(name) - return dict(cast("dict[str, Any]", value)) if isinstance(value, dict) else {} - - -def _get_first_value(*values: Any, default: Any = None) -> Any: - """Return the first non-None value.""" - - for value in values: - if value is not None: - return value - return default - - def _get_adk_session_store_config(config: _ADKConfigSource) -> _ADKSessionStoreConfig: """Return normalized session store table settings.""" adk_config = _get_adk_config_from_extension(config) - schema_config = _get_adk_config_section(adk_config, "schema") - session_table = _get_first_value(schema_config.get("session_table"), adk_config.get("session_table")) - events_table = _get_first_value(schema_config.get("events_table"), adk_config.get("events_table")) - app_state_table = _get_first_value(schema_config.get("app_state_table"), adk_config.get("app_state_table")) - user_state_table = _get_first_value(schema_config.get("user_state_table"), adk_config.get("user_state_table")) - metadata_table = _get_first_value(schema_config.get("metadata_table"), adk_config.get("metadata_table")) result: _ADKSessionStoreConfig = { - "session_table": str(session_table) if session_table is not None else "adk_session", - "events_table": str(events_table) if events_table is not None else "adk_event", - "app_state_table": str(app_state_table) if app_state_table is not None else "adk_app_state", - "user_state_table": str(user_state_table) if user_state_table is not None else "adk_user_state", - "metadata_table": str(metadata_table) if metadata_table is not None else "adk_internal_metadata", + "session_table": str(adk_config.get("session_table") or "adk_session"), + "events_table": str(adk_config.get("events_table") or "adk_event"), + "app_state_table": str(adk_config.get("app_state_table") or "adk_app_state"), + "user_state_table": str(adk_config.get("user_state_table") or "adk_user_state"), + "metadata_table": str(adk_config.get("metadata_table") or "adk_metadata"), } - owner_id = _get_first_value(schema_config.get("owner_id_column"), adk_config.get("owner_id_column")) + owner_id = adk_config.get("owner_id_column") if owner_id is not None: result["owner_id_column"] = cast("str", owner_id) return result @@ -115,34 +87,15 @@ def _get_adk_memory_store_config(config: _ADKConfigSource) -> _ADKMemoryStoreCon """Return normalized memory store settings.""" adk_config = _get_adk_config_from_extension(config) - schema_config = _get_adk_config_section(adk_config, "schema") - memory_config = _get_adk_config_section(adk_config, "memory") - search_config = _get_adk_config_section(adk_config, "search") - nested_memory_search_config = _get_adk_config_section(memory_config, "search") - enable_memory = _get_first_value(memory_config.get("enabled"), adk_config.get("enable_memory")) - memory_table = _get_first_value( - memory_config.get("table"), schema_config.get("memory_table"), adk_config.get("memory_table") - ) - use_fts = _get_first_value( - nested_memory_search_config.get("use_fts"), - search_config.get("use_fts"), - memory_config.get("use_fts"), - adk_config.get("memory_use_fts"), - ) - max_results = _get_first_value( - memory_config.get("max_results"), - nested_memory_search_config.get("max_results"), - search_config.get("max_results"), - adk_config.get("memory_max_results"), - ) - + enable_memory = adk_config.get("enable_memory") + max_results = adk_config.get("memory_max_results") result: _ADKMemoryStoreConfig = { "enable_memory": bool(enable_memory) if enable_memory is not None else True, - "memory_table": str(memory_table) if memory_table is not None else "adk_memory_entries", - "use_fts": bool(use_fts) if use_fts is not None else False, - "max_results": int(max_results) if type(max_results) is int else 20, + "memory_table": str(adk_config.get("memory_table") or "adk_memory"), + "use_fts": bool(adk_config.get("memory_use_fts", False)), + "max_results": int(max_results) if isinstance(max_results, int) else 20, } - owner_id = _get_first_value(schema_config.get("owner_id_column"), adk_config.get("owner_id_column")) + owner_id = adk_config.get("owner_id_column") if owner_id is not None: result["owner_id_column"] = cast("str", owner_id) return result @@ -152,40 +105,15 @@ def _get_adk_artifact_store_config(config: _ADKConfigSource) -> _ADKArtifactStor """Return normalized artifact store settings.""" adk_config = _get_adk_config_from_extension(config) - schema_config = _get_adk_config_section(adk_config, "schema") - artifact_config = _get_adk_config_section(adk_config, "artifact") - artifact_table = _get_first_value( - artifact_config.get("table"), schema_config.get("artifact_table"), adk_config.get("artifact_table") - ) result: _ADKArtifactStoreConfig = { - "artifact_table": str(artifact_table) if artifact_table is not None else "adk_artifact_versions" + "artifact_table": str(adk_config.get("artifact_table") or "adk_artifact") } - storage_uri = _get_first_value(artifact_config.get("storage_uri"), adk_config.get("artifact_storage_uri")) + storage_uri = adk_config.get("artifact_storage_uri") if storage_uri is not None: result["storage_uri"] = str(storage_uri) return result -def _get_adk_version_plan(config: _ADKConfigSource) -> ADKVersionPlan: - """Return normalized ADK schema and payload version settings.""" - - return resolve_adk_version_plan(_get_adk_config_from_extension(config)) - - -def _get_adk_lifecycle_plan(config: _ADKConfigSource) -> ADKLifecyclePlan: - """Return normalized ADK lifecycle control settings.""" - - return resolve_adk_lifecycle_plan(_get_adk_config_from_extension(config)) - - -def _get_adk_capability_overrides(config: _ADKConfigSource) -> dict[str, ADKCapabilityMode]: - """Return normalized ADK capability overrides.""" - - adk_config = _get_adk_config_from_extension(config) - capabilities_config = _get_adk_config_section(adk_config, "capabilities") - return normalize_adk_capability_overrides(_get_adk_config_section(capabilities_config, "overrides")) - - def _resolve_adk_store_path(config: Any, store_suffix: str) -> str: """Return the adapter-specific ADK store import path.""" @@ -254,16 +182,11 @@ def _is_adk_memory_migration_enabled(config: Any) -> bool: """Return whether ADK memory DDL should be included for this config.""" adk_config = _get_adk_config_from_extension(cast("_ADKConfigSource", config)) - schema_config = _get_adk_config_section(adk_config, "schema") - memory_config = _get_adk_config_section(adk_config, "memory") - include_memory = _get_first_value( - schema_config.get("include_memory_migration"), - memory_config.get("include_migration"), - adk_config.get("include_memory_migration"), - ) + include_memory = adk_config.get("include_memory_migration") if include_memory is not None: return bool(include_memory) - return bool(_get_first_value(memory_config.get("enabled"), adk_config.get("enable_memory"), default=True)) + enable_memory = adk_config.get("enable_memory") + return bool(enable_memory) if enable_memory is not None else True def _validate_adk_store_registration(config: Any) -> None: diff --git a/sqlspec/extensions/adk/_lifecycle.py b/sqlspec/extensions/adk/_lifecycle.py deleted file mode 100644 index 9f497e47b..000000000 --- a/sqlspec/extensions/adk/_lifecycle.py +++ /dev/null @@ -1,105 +0,0 @@ -"""ADK lifecycle control resolution.""" - -from collections.abc import Mapping -from dataclasses import dataclass, field -from typing import Final, Literal, TypeAlias, cast - -from sqlspec.exceptions import ImproperConfigurationError - -__all__ = ("ADKLifecyclePlan", "resolve_adk_lifecycle_plan", "validate_adk_lifecycle_plan") - -ADKLifecycleMode: TypeAlias = Literal["auto", "enable", "disable"] - -ADK_INDEXING_CONTROLS: Final[tuple[str, ...]] = ( - "generated_columns", - "covering_indexes", - "search_indexes", - "json_indexes", - "vector_indexes", -) -ADK_LIFECYCLE_MODES: Final[frozenset[str]] = frozenset({"auto", "enable", "disable"}) - - -@dataclass(frozen=True, slots=True) -class ADKLifecyclePlan: - """Resolved ADK lifecycle controls used by backend DDL chapters.""" - - partitioning: dict[str, object] | None = None - retention: dict[str, object] | None = None - indexing: dict[str, ADKLifecycleMode] = field(default_factory=dict) - compression: dict[str, object] | None = None - table_options: dict[str, str] = field(default_factory=dict) - - -def resolve_adk_lifecycle_plan(adk_config: Mapping[str, object] | None = None) -> ADKLifecyclePlan: - """Resolve lifecycle controls from ADK extension config.""" - - config = adk_config or {} - lifecycle_config = _mapping(config.get("lifecycle")) - plan = ADKLifecyclePlan( - partitioning=_optional_mapping(_first_value(lifecycle_config.get("partitioning"), config.get("partitioning"))), - retention=_optional_mapping(_first_value(lifecycle_config.get("retention"), config.get("retention"))), - indexing=_resolve_indexing_config(config, lifecycle_config), - compression=_optional_mapping(_first_value(lifecycle_config.get("compression"), config.get("compression"))), - table_options=_resolve_table_options(config, lifecycle_config), - ) - validate_adk_lifecycle_plan(plan) - return plan - - -def validate_adk_lifecycle_plan(plan: ADKLifecyclePlan) -> None: - """Validate lifecycle controls that have shared semantics.""" - - for key, value in plan.indexing.items(): - if value not in ADK_LIFECYCLE_MODES: - msg = f"Unsupported ADK lifecycle indexing mode {value!r} for {key}; expected auto, enable, or disable" - raise ImproperConfigurationError(msg) - - -def _resolve_indexing_config( - config: Mapping[str, object], lifecycle_config: Mapping[str, object] -) -> dict[str, ADKLifecycleMode]: - lifecycle_indexing = _mapping(lifecycle_config.get("indexing")) - top_level_indexing = _mapping(config.get("indexing")) - optimizations = _mapping(config.get("optimizations")) - resolved: dict[str, ADKLifecycleMode] = {} - for key in ADK_INDEXING_CONTROLS: - value = _first_value(lifecycle_indexing.get(key), top_level_indexing.get(key), optimizations.get(key), "auto") - resolved[key] = _indexing_mode(key, value) - return resolved - - -def _resolve_table_options(config: Mapping[str, object], lifecycle_config: Mapping[str, object]) -> dict[str, str]: - flat_options = { - "sessions": config.get("session_table_options"), - "events": config.get("events_table_options"), - "memory": config.get("memory_table_options"), - "expires_index": config.get("expires_index_options"), - } - resolved = {key: str(value) for key, value in flat_options.items() if value is not None} - resolved.update({key: str(value) for key, value in _mapping(lifecycle_config.get("table_options")).items()}) - return resolved - - -def _mapping(value: object) -> Mapping[str, object]: - return value if isinstance(value, Mapping) else {} - - -def _optional_mapping(value: object | None) -> dict[str, object] | None: - if isinstance(value, Mapping): - return dict(value) - return None - - -def _first_value(*values: object) -> object | None: - for value in values: - if value is not None: - return value - return None - - -def _indexing_mode(key: str, value: object) -> ADKLifecycleMode: - if value in ADK_LIFECYCLE_MODES: - return cast("ADKLifecycleMode", value) - msg = f"Unsupported ADK lifecycle indexing mode {value!r} for {key}; expected auto, enable, or disable" - raise ImproperConfigurationError(msg) diff --git a/sqlspec/extensions/adk/_versioning.py b/sqlspec/extensions/adk/_versioning.py deleted file mode 100644 index 4375bafdf..000000000 --- a/sqlspec/extensions/adk/_versioning.py +++ /dev/null @@ -1,142 +0,0 @@ -"""ADK schema and payload version planning.""" - -from collections.abc import Mapping -from dataclasses import dataclass -from typing import Final, Literal, TypeAlias - -from sqlspec.exceptions import ImproperConfigurationError - -__all__ = ("ADKVersionPlan", "resolve_adk_version_plan", "validate_adk_version_plan") - - -ADKPayloadKind: TypeAlias = Literal["event", "state", "memory", "artifact"] - -ADK_SCHEMA_VERSION: Final = 1 -ADK_EVENT_PAYLOAD_VERSION: Final = 1 -ADK_STATE_PAYLOAD_VERSION: Final = 1 -ADK_MEMORY_PAYLOAD_VERSION: Final = 1 -ADK_ARTIFACT_PAYLOAD_VERSION: Final = 1 - -ADK_SCHEMA_VERSION_KEY: Final = "schema_version" -ADK_PAYLOAD_VERSION_KEYS: Final[dict[ADKPayloadKind, str]] = { - "event": "sqlspec.adk.payload.event", - "state": "sqlspec.adk.payload.state", - "memory": "sqlspec.adk.payload.memory", - "artifact": "sqlspec.adk.payload.artifact", -} - -SUPPORTED_ADK_SCHEMA_VERSIONS: Final[frozenset[int]] = frozenset({ADK_SCHEMA_VERSION}) -SUPPORTED_ADK_PAYLOAD_VERSIONS: Final[dict[ADKPayloadKind, frozenset[int]]] = { - "event": frozenset({ADK_EVENT_PAYLOAD_VERSION}), - "state": frozenset({ADK_STATE_PAYLOAD_VERSION}), - "memory": frozenset({ADK_MEMORY_PAYLOAD_VERSION}), - "artifact": frozenset({ADK_ARTIFACT_PAYLOAD_VERSION}), -} - - -@dataclass(frozen=True, slots=True) -class ADKVersionPlan: - """Resolved ADK schema and payload version contract.""" - - schema_version: int = ADK_SCHEMA_VERSION - event_payload_version: int = ADK_EVENT_PAYLOAD_VERSION - state_payload_version: int = ADK_STATE_PAYLOAD_VERSION - memory_payload_version: int = ADK_MEMORY_PAYLOAD_VERSION - artifact_payload_version: int = ADK_ARTIFACT_PAYLOAD_VERSION - - def payload_versions(self) -> dict[ADKPayloadKind, int]: - """Return payload versions keyed by payload kind.""" - - return { - "event": self.event_payload_version, - "state": self.state_payload_version, - "memory": self.memory_payload_version, - "artifact": self.artifact_payload_version, - } - - def metadata_items(self) -> tuple[tuple[str, str], ...]: - """Return deterministic metadata rows for the ADK metadata table.""" - - return ( - (ADK_SCHEMA_VERSION_KEY, str(self.schema_version)), - (ADK_PAYLOAD_VERSION_KEYS["event"], str(self.event_payload_version)), - (ADK_PAYLOAD_VERSION_KEYS["state"], str(self.state_payload_version)), - (ADK_PAYLOAD_VERSION_KEYS["memory"], str(self.memory_payload_version)), - (ADK_PAYLOAD_VERSION_KEYS["artifact"], str(self.artifact_payload_version)), - ) - - -def resolve_adk_version_plan(adk_config: Mapping[str, object] | None = None) -> ADKVersionPlan: - """Resolve the configured ADK schema and payload versions.""" - - config = adk_config or {} - schema_config = _mapping(config.get("schema")) - schema_payloads = _mapping(schema_config.get("payload_versions")) - top_level_payloads = _mapping(config.get("payloads")) - plan = ADKVersionPlan( - schema_version=_version_value( - _first_value(schema_config.get("schema_version"), config.get("schema_version")), - default=ADK_SCHEMA_VERSION, - label="schema.schema_version", - ), - event_payload_version=_payload_version("event", schema_payloads, top_level_payloads), - state_payload_version=_payload_version("state", schema_payloads, top_level_payloads), - memory_payload_version=_payload_version("memory", schema_payloads, top_level_payloads), - artifact_payload_version=_payload_version("artifact", schema_payloads, top_level_payloads), - ) - validate_adk_version_plan(plan) - return plan - - -def validate_adk_version_plan(plan: ADKVersionPlan) -> None: - """Validate that a resolved ADK version plan is supported.""" - - if plan.schema_version not in SUPPORTED_ADK_SCHEMA_VERSIONS: - _raise_unsupported_version( - "schema", plan.schema_version, sorted(SUPPORTED_ADK_SCHEMA_VERSIONS), "schema.schema_version" - ) - for payload_kind, payload_version in plan.payload_versions().items(): - supported = SUPPORTED_ADK_PAYLOAD_VERSIONS[payload_kind] - if payload_version not in supported: - _raise_unsupported_version(payload_kind, payload_version, sorted(supported), "schema.payload_versions") - - -def _payload_version( - payload_kind: ADKPayloadKind, schema_payloads: Mapping[str, object], top_level_payloads: Mapping[str, object] -) -> int: - default_versions = { - "event": ADK_EVENT_PAYLOAD_VERSION, - "state": ADK_STATE_PAYLOAD_VERSION, - "memory": ADK_MEMORY_PAYLOAD_VERSION, - "artifact": ADK_ARTIFACT_PAYLOAD_VERSION, - } - return _version_value( - _first_value(schema_payloads.get(payload_kind), top_level_payloads.get(payload_kind)), - default=default_versions[payload_kind], - label=f"schema.payload_versions.{payload_kind}", - ) - - -def _mapping(value: object) -> Mapping[str, object]: - return value if isinstance(value, Mapping) else {} - - -def _first_value(*values: object) -> object | None: - for value in values: - if value is not None: - return value - return None - - -def _version_value(value: object | None, *, default: int, label: str) -> int: - if value is None: - return default - if type(value) is not int: - msg = f"ADK {label} must be an integer version, got {value!r}" - raise ImproperConfigurationError(msg) - return value - - -def _raise_unsupported_version(kind: str, version: int, supported: list[int], label: str) -> None: - msg = f"Unsupported ADK {kind} version {version!r} from {label}; supported versions: {supported}" - raise ImproperConfigurationError(msg) diff --git a/sqlspec/extensions/adk/artifact/__init__.py b/sqlspec/extensions/adk/artifact/__init__.py index 72255a877..c6f03a90d 100644 --- a/sqlspec/extensions/adk/artifact/__init__.py +++ b/sqlspec/extensions/adk/artifact/__init__.py @@ -18,7 +18,7 @@ connection_config={"dsn": "postgresql://..."}, extension_config={ "adk": { - "artifact_table": "adk_artifact_versions", + "artifact_table": "adk_artifact", } } ) diff --git a/sqlspec/extensions/adk/artifact/store.py b/sqlspec/extensions/adk/artifact/store.py index a9cd7a353..100d404e0 100644 --- a/sqlspec/extensions/adk/artifact/store.py +++ b/sqlspec/extensions/adk/artifact/store.py @@ -42,7 +42,7 @@ class BaseAsyncADKArtifactStore(ABC, Generic[ConfigT]): Notes: Configuration is read from config.extension_config["adk"]: - - artifact_table: Artifact versions table name (default: "adk_artifact_versions") + - artifact_table: Artifact versions table name (default: "adk_artifact") """ __slots__ = ("_artifact_table", "_config") diff --git a/sqlspec/extensions/adk/memory/__init__.py b/sqlspec/extensions/adk/memory/__init__.py index 40798dffa..66e4a6748 100644 --- a/sqlspec/extensions/adk/memory/__init__.py +++ b/sqlspec/extensions/adk/memory/__init__.py @@ -21,7 +21,7 @@ connection_config={"dsn": "postgresql://..."}, extension_config={ "adk": { - "memory_table": "adk_memory_entries", + "memory_table": "adk_memory", "memory_use_fts": True, "memory_max_results": 50, } diff --git a/sqlspec/extensions/adk/memory/presets.py b/sqlspec/extensions/adk/memory/presets.py index 8d2de10d1..78874627b 100644 --- a/sqlspec/extensions/adk/memory/presets.py +++ b/sqlspec/extensions/adk/memory/presets.py @@ -2,7 +2,7 @@ Resolution order (highest priority first): -1. ``embedding_dimension`` explicit override on ``ADKMemoryConfig`` +1. ``embedding_dimension`` explicit override on the ``memory`` block 2. ``embedding_preset`` name resolved against :data:`EMBEDDING_PRESETS` 3. Raise ``ADKConfigError`` with the preset table referenced in the error message. @@ -17,6 +17,7 @@ from sqlspec.exceptions import ImproperConfigurationError __all__ = ( + "DEFAULT_EMBEDDING_PRESET", "EMBEDDING_PRESETS", "EmbeddingPreset", "ResolvedEmbeddingConfig", @@ -119,18 +120,23 @@ def register_embedding_preset(name: str, preset: EmbeddingPreset) -> None: EMBEDDING_PRESETS[name] = preset +DEFAULT_EMBEDDING_PRESET: Final[str] = "gemini-embedding-002" + + def resolve_embedding_config(memory_config: "dict[str, object] | None") -> ResolvedEmbeddingConfig: - """Resolve an :class:`ResolvedEmbeddingConfig` from an ``ADKMemoryConfig`` mapping. + """Resolve a :class:`ResolvedEmbeddingConfig` from an ADK memory config mapping. Args: memory_config: ``extension_config["adk"]["memory"]`` mapping. Returns: - Resolved embedding configuration. + Resolved embedding configuration. Falls back to + :data:`DEFAULT_EMBEDDING_PRESET` when neither ``embedding_dimension`` nor + ``embedding_preset`` is supplied. Raises: - ImproperConfigurationError: When neither ``embedding_dimension`` nor - ``embedding_preset`` is supplied, or the named preset is unknown. + ImproperConfigurationError: When the named preset is unknown or + ``embedding_dimension`` is not an int. """ config = memory_config or {} preset_name = config.get("embedding_preset") @@ -158,29 +164,22 @@ def resolve_embedding_config(memory_config: "dict[str, object] | None") -> Resol preset=preset, ) - if preset is not None: - return ResolvedEmbeddingConfig( - dim=preset.dim, - precision=str(explicit_precision) if explicit_precision else preset.precision, - normalize=bool(explicit_normalize) if explicit_normalize is not None else preset.normalize, - source="embedding_preset", - preset=preset, - ) - - _raise_unresolved() - return None + if preset is None: + preset = EMBEDDING_PRESETS[DEFAULT_EMBEDDING_PRESET] + source = "default" + else: + source = "embedding_preset" + + return ResolvedEmbeddingConfig( + dim=preset.dim, + precision=str(explicit_precision) if explicit_precision else preset.precision, + normalize=bool(explicit_normalize) if explicit_normalize is not None else preset.normalize, + source=source, + preset=preset, + ) def _raise_unknown_preset(name: str) -> NoReturn: available = ", ".join(sorted(EMBEDDING_PRESETS)) msg = f"Unknown embedding preset {name!r}. Available presets: {available}" raise ImproperConfigurationError(msg) - - -def _raise_unresolved() -> NoReturn: - available = ", ".join(sorted(EMBEDDING_PRESETS)) - msg = ( - "ADK memory store requires either embedding_dimension or embedding_preset " - f"to be set in extension_config['adk']['memory']. Available presets: {available}" - ) - raise ImproperConfigurationError(msg) diff --git a/sqlspec/extensions/adk/memory/service.py b/sqlspec/extensions/adk/memory/service.py index 7f106df01..58de30eb9 100644 --- a/sqlspec/extensions/adk/memory/service.py +++ b/sqlspec/extensions/adk/memory/service.py @@ -49,7 +49,7 @@ class SQLSpecMemoryService(BaseMemoryService): connection_config={"dsn": "postgresql://..."}, extension_config={ "adk": { - "memory_table": "adk_memory_entries", + "memory_table": "adk_memory", "memory_use_fts": True, } } diff --git a/sqlspec/extensions/adk/memory/store.py b/sqlspec/extensions/adk/memory/store.py index f7a7f3091..93e4c31ae 100644 --- a/sqlspec/extensions/adk/memory/store.py +++ b/sqlspec/extensions/adk/memory/store.py @@ -64,7 +64,7 @@ class BaseAsyncADKMemoryStore(ABC, Generic[ConfigT]): Notes: Configuration is read from config.extension_config["adk"]: - - memory_table: Memory table name (default: "adk_memory_entries") + - memory_table: Memory table name (default: "adk_memory") - memory_use_fts: Enable full-text search when supported (default: False) - memory_max_results: Max search results (default: 20) - owner_id_column: Optional owner FK column DDL (default: None) @@ -89,7 +89,7 @@ def __init__(self, config: ConfigT) -> None: Notes: Reads configuration from config.extension_config["adk"]: - - memory_table: Memory table name (default: "adk_memory_entries") + - memory_table: Memory table name (default: "adk_memory") - memory_use_fts: Enable full-text search when supported (default: False) - memory_max_results: Max search results (default: 20) - owner_id_column: Optional owner FK column DDL (default: None) diff --git a/sqlspec/extensions/adk/migrations/0001_create_adk_tables.py b/sqlspec/extensions/adk/migrations/0001_create_adk_tables.py index b288d8ab3..8381a35c5 100644 --- a/sqlspec/extensions/adk/migrations/0001_create_adk_tables.py +++ b/sqlspec/extensions/adk/migrations/0001_create_adk_tables.py @@ -1,183 +1,22 @@ -"""Create ADK session, events, and memory tables migration using store DDL definitions.""" +"""No-op migration: superseded by 0002_reset_adk_tables. -import logging -from typing import TYPE_CHECKING, NoReturn, cast +This file used to create the legacy ADK ``sessions`` / ``events`` tables. The +ADK 2.0 clean break replaces that schema in 0002. 0001 is retained as a no-op +so installs that already applied it keep their tracking-table row; fresh +installs run it as a no-op and proceed to 0002. +""" -from sqlspec.exceptions import SQLSpecError -from sqlspec.extensions.adk._config_utils import ( - _get_adk_adapter_store_class, - _get_adk_memory_migration_store_class, - _is_adk_memory_migration_enabled, -) -from sqlspec.utils.logging import get_logger, log_with_context +from typing import TYPE_CHECKING if TYPE_CHECKING: - from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore - from sqlspec.extensions.adk.store import BaseAsyncADKStore from sqlspec.migrations.context import MigrationContext -logger = get_logger("sqlspec.migrations.adk.tables") - __all__ = ("down", "up") -def _get_store_class(context: "MigrationContext | None") -> "type[BaseAsyncADKStore]": - """Get the appropriate store class based on the config's module path. - - Args: - context: Migration context containing config. - - Returns: - Store class matching the config's adapter. - - Notes: - Dynamically imports the store class from the config's module path. - For example, AsyncpgConfig at 'sqlspec.adapters.asyncpg.config' - maps to AsyncpgADKStore at 'sqlspec.adapters.asyncpg.adk.AsyncpgADKStore'. - """ - if not context or not context.config: - _raise_missing_config() - - return cast("type[BaseAsyncADKStore]", _get_adk_adapter_store_class(context.config, "ADKStore")) - - -def _get_memory_store_class(context: "MigrationContext | None") -> "type[BaseAsyncADKMemoryStore] | None": - """Get the appropriate memory store class based on the config's module path. - - Args: - context: Migration context containing config. - - Returns: - Memory store class matching the config's adapter, or None if not available. - - Notes: - Dynamically imports the memory store class from the config's module path. - For example, AsyncpgConfig at 'sqlspec.adapters.asyncpg.config' - maps to AsyncpgADKMemoryStore at 'sqlspec.adapters.asyncpg.adk.AsyncpgADKMemoryStore'. - """ - if not context or not context.config: - return None - - store_class = _get_adk_memory_migration_store_class(context.config) - if store_class is None: - log_with_context(logger, logging.DEBUG, "adk.migration.memory_store.missing") - return None - return cast("type[BaseAsyncADKMemoryStore]", store_class) - - -def _is_memory_enabled(context: "MigrationContext | None") -> bool: - """Check if memory migration is enabled in the config. - - Args: - context: Migration context containing config. - - Returns: - True if memory migration should be included, False otherwise. - - Notes: - Checks config.extension_config["adk"]["include_memory_migration"]. - Defaults to True if not specified and enable_memory is True. - """ - if not context or not context.config: - return False - - return _is_adk_memory_migration_enabled(context.config) - - -def _raise_missing_config() -> NoReturn: - """Raise error when migration context has no config. - - Raises: - SQLSpecError: Always raised. - """ - msg = "Migration context must have a config to determine store class" - raise SQLSpecError(msg) - - async def up(context: "MigrationContext | None" = None) -> "list[str]": - """Create the ADK session, events, and memory tables using store DDL definitions. - - This migration delegates to the appropriate store class to generate - dialect-specific DDL. The store classes contain the single source of - truth for table schemas. - - Args: - context: Migration context containing config. - - Returns: - List of SQL statements to execute for upgrade. - - Notes: - Configuration is read from context.config.extension_config["adk"]. - Supports custom table names and optional owner_id_column for linking - sessions to owner tables (users, tenants, teams, etc.). - Memory table is included if enable_memory or include_memory_migration is True. - """ - if context is None or context.config is None: - _raise_missing_config() - - store_class = _get_store_class(context) - store_instance = store_class(config=context.config) - - statements = [ - await store_instance._get_create_sessions_table_sql(), # pyright: ignore[reportPrivateUsage] - await store_instance._get_create_events_table_sql(), # pyright: ignore[reportPrivateUsage] - await store_instance._get_create_app_states_table_sql(), # pyright: ignore[reportPrivateUsage] - await store_instance._get_create_user_states_table_sql(), # pyright: ignore[reportPrivateUsage] - await store_instance._get_create_metadata_table_sql(), # pyright: ignore[reportPrivateUsage] - await store_instance._get_seed_metadata_sql(), # pyright: ignore[reportPrivateUsage] - ] - - if _is_memory_enabled(context): - memory_store_class = _get_memory_store_class(context) - if memory_store_class is not None: - memory_store = memory_store_class(config=context.config) - memory_sql = await memory_store._get_create_memory_table_sql() # pyright: ignore[reportPrivateUsage] - if isinstance(memory_sql, list): - statements.extend(memory_sql) - else: - statements.append(memory_sql) - log_with_context( - logger, logging.DEBUG, "adk.migration.memory.include", table_name=memory_store.memory_table - ) - - return statements + return [] async def down(context: "MigrationContext | None" = None) -> "list[str]": - """Drop the ADK session, events, and memory tables using store DDL definitions. - - This migration delegates to the appropriate store class to generate - dialect-specific DROP statements. The store classes contain the single - source of truth for table schemas. - - Args: - context: Migration context containing config. - - Returns: - List of SQL statements to execute for downgrade. - - Notes: - Configuration is read from context.config.extension_config["adk"]. - Memory table is included if enable_memory or include_memory_migration is True. - """ - if context is None or context.config is None: - _raise_missing_config() - - statements: list[str] = [] - - if _is_memory_enabled(context): - memory_store_class = _get_memory_store_class(context) - if memory_store_class is not None: - memory_store = memory_store_class(config=context.config) - memory_drop_stmts = memory_store._get_drop_memory_table_sql() # pyright: ignore[reportPrivateUsage] - statements.extend(memory_drop_stmts) - log_with_context( - logger, logging.DEBUG, "adk.migration.memory.drop.include", table_name=memory_store.memory_table - ) - - store_class = _get_store_class(context) - store_instance = store_class(config=context.config) - statements.extend(store_instance._get_drop_tables_sql()) # pyright: ignore[reportPrivateUsage] - - return statements + return [] diff --git a/sqlspec/extensions/adk/migrations/0002_reset_adk_tables.py b/sqlspec/extensions/adk/migrations/0002_reset_adk_tables.py new file mode 100644 index 000000000..ff7e76c10 --- /dev/null +++ b/sqlspec/extensions/adk/migrations/0002_reset_adk_tables.py @@ -0,0 +1,118 @@ +"""Reset ADK schema to the 2.0 clean-break shape. + +Unconditionally drops any legacy ADK tables (sessions, events, app_states, +user_states, metadata, memory) then creates the new schema and seeds the +internal metadata row. The memory table is dropped unconditionally so users +moving from ``enable_memory=True`` to ``enable_memory=False`` get cleanup; it +is recreated only when memory is enabled for the current config. +""" + +import logging +from typing import TYPE_CHECKING, NoReturn, cast + +from sqlspec.exceptions import SQLSpecError +from sqlspec.extensions.adk._config_utils import ( + _get_adk_adapter_store_class, + _get_adk_memory_migration_store_class, + _is_adk_memory_migration_enabled, +) +from sqlspec.utils.logging import get_logger, log_with_context + +if TYPE_CHECKING: + from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore + from sqlspec.extensions.adk.store import BaseAsyncADKStore + from sqlspec.migrations.context import MigrationContext + +logger = get_logger("sqlspec.migrations.adk.reset") + +__all__ = ("down", "up") + + +def _raise_missing_config() -> NoReturn: + msg = "Migration context must have a config to determine store class" + raise SQLSpecError(msg) + + +def _get_store_class(context: "MigrationContext | None") -> "type[BaseAsyncADKStore]": + if not context or not context.config: + _raise_missing_config() + return cast("type[BaseAsyncADKStore]", _get_adk_adapter_store_class(context.config, "ADKStore")) + + +def _get_memory_store_class(context: "MigrationContext | None") -> "type[BaseAsyncADKMemoryStore] | None": + if not context or not context.config: + return None + store_class = _get_adk_memory_migration_store_class(context.config) + if store_class is None: + log_with_context(logger, logging.DEBUG, "adk.migration.reset.memory_store.missing") + return None + return cast("type[BaseAsyncADKMemoryStore]", store_class) + + +def _is_memory_enabled(context: "MigrationContext | None") -> bool: + if not context or not context.config: + return False + return _is_adk_memory_migration_enabled(context.config) + + +async def up(context: "MigrationContext | None" = None) -> "list[str]": + if context is None or context.config is None: + _raise_missing_config() + + store_class = _get_store_class(context) + store_instance = store_class(config=context.config) + + statements: list[str] = [] + + memory_store_class = _get_memory_store_class(context) + if memory_store_class is not None: + memory_store = memory_store_class(config=context.config) + statements.extend(memory_store._get_drop_memory_table_sql()) # pyright: ignore[reportPrivateUsage] + log_with_context( + logger, logging.DEBUG, "adk.migration.reset.memory.drop", table_name=memory_store.memory_table + ) + + statements.extend(store_instance._get_drop_tables_sql()) # pyright: ignore[reportPrivateUsage] + + statements.extend( + [ + await store_instance._get_create_sessions_table_sql(), # pyright: ignore[reportPrivateUsage] + await store_instance._get_create_events_table_sql(), # pyright: ignore[reportPrivateUsage] + await store_instance._get_create_app_states_table_sql(), # pyright: ignore[reportPrivateUsage] + await store_instance._get_create_user_states_table_sql(), # pyright: ignore[reportPrivateUsage] + await store_instance._get_create_metadata_table_sql(), # pyright: ignore[reportPrivateUsage] + await store_instance._get_seed_metadata_sql(), # pyright: ignore[reportPrivateUsage] + ] + ) + + if _is_memory_enabled(context) and memory_store_class is not None: + memory_store = memory_store_class(config=context.config) + memory_sql = await memory_store._get_create_memory_table_sql() # pyright: ignore[reportPrivateUsage] + if isinstance(memory_sql, list): + statements.extend(memory_sql) + else: + statements.append(memory_sql) + log_with_context( + logger, logging.DEBUG, "adk.migration.reset.memory.create", table_name=memory_store.memory_table + ) + + return statements + + +async def down(context: "MigrationContext | None" = None) -> "list[str]": + if context is None or context.config is None: + _raise_missing_config() + + statements: list[str] = [] + + if _is_memory_enabled(context): + memory_store_class = _get_memory_store_class(context) + if memory_store_class is not None: + memory_store = memory_store_class(config=context.config) + statements.extend(memory_store._get_drop_memory_table_sql()) # pyright: ignore[reportPrivateUsage] + + store_class = _get_store_class(context) + store_instance = store_class(config=context.config) + statements.extend(store_instance._get_drop_tables_sql()) # pyright: ignore[reportPrivateUsage] + + return statements diff --git a/sqlspec/extensions/adk/service.py b/sqlspec/extensions/adk/service.py index 75891f1a8..1c32d1ea8 100644 --- a/sqlspec/extensions/adk/service.py +++ b/sqlspec/extensions/adk/service.py @@ -283,16 +283,21 @@ async def append_event(self, session: "Session", event: "Event") -> "Event": ) raise ValueError(msg) - # --- Persist event and all scoped state atomically --- - updated_record = await self._store.append_event_and_update_state( - event_record=event_record, - app_name=session.app_name, - user_id=session.user_id, - session_id=session.id, - state=session_state, - app_state=app_state or None, - user_state=user_state or None, - ) + state_delta = (event.actions.state_delta if event.actions else None) or {} + + if not state_delta: + await self._store.append_event(event_record) + updated_record = current_record + else: + updated_record = await self._store.append_event_and_update_state( + event_record=event_record, + app_name=session.app_name, + user_id=session.user_id, + session_id=session.id, + state=session_state, + app_state=app_state or None, + user_state=user_state or None, + ) updated_record["state"] = merge_scoped_state(updated_record["state"], app_state, user_state) # Use the returned record directly — saves a round-trip vs a follow-up get_session(). diff --git a/sqlspec/extensions/adk/store.py b/sqlspec/extensions/adk/store.py index 00c0e7057..0fdfbcd54 100644 --- a/sqlspec/extensions/adk/store.py +++ b/sqlspec/extensions/adk/store.py @@ -79,7 +79,7 @@ class BaseAsyncADKStore(ABC, Generic[ConfigT]): - events_table: Events table name (default: "adk_event") - app_state_table: App-scoped state table name (default: "adk_app_state") - user_state_table: User-scoped state table name (default: "adk_user_state") - - metadata_table: Internal metadata table name (default: "adk_internal_metadata") + - metadata_table: Internal metadata table name (default: "adk_metadata") - owner_id_column: Optional owner FK column DDL (default: None) """ @@ -106,7 +106,7 @@ def __init__(self, config: ConfigT) -> None: - events_table: Events table name (default: "adk_event") - app_state_table: App-scoped state table name (default: "adk_app_state") - user_state_table: User-scoped state table name (default: "adk_user_state") - - metadata_table: Internal metadata table name (default: "adk_internal_metadata") + - metadata_table: Internal metadata table name (default: "adk_metadata") - owner_id_column: Optional owner FK column DDL (default: None) """ self._config = config diff --git a/tests/integration/adapters/_adk_contract_helpers.py b/tests/integration/adapters/_adk_contract_helpers.py index e064fb260..0fe3549b5 100644 --- a/tests/integration/adapters/_adk_contract_helpers.py +++ b/tests/integration/adapters/_adk_contract_helpers.py @@ -673,3 +673,5 @@ async def assert_memory_store_contract(store: MemoryStore, *, marker: str) -> No fresh_results = await store.search_entries("fresh", app_name, user_id, limit=10) assert len(fresh_results) == 1 assert fresh_results[0]["event_id"] == fresh_record["event_id"] + + diff --git a/tests/integration/adapters/aiosqlite/extensions/adk/test_memory_store.py b/tests/integration/adapters/aiosqlite/extensions/adk/test_memory_store.py index 0bc65b2e0..e17e19524 100644 --- a/tests/integration/adapters/aiosqlite/extensions/adk/test_memory_store.py +++ b/tests/integration/adapters/aiosqlite/extensions/adk/test_memory_store.py @@ -101,7 +101,7 @@ async def test_aiosqlite_memory_store_disabled_lifecycle() -> None: async with config.provide_connection() as conn: cursor = await conn.execute( - "SELECT name FROM sqlite_master WHERE type = 'table' AND name = ?", ("adk_memory_entries",) + "SELECT name FROM sqlite_master WHERE type = 'table' AND name = ?", ("adk_memory",) ) row = await cursor.fetchone() diff --git a/tests/integration/adapters/asyncpg/extensions/adk/conftest.py b/tests/integration/adapters/asyncpg/extensions/adk/conftest.py index 45e88839f..2d54c7233 100644 --- a/tests/integration/adapters/asyncpg/extensions/adk/conftest.py +++ b/tests/integration/adapters/asyncpg/extensions/adk/conftest.py @@ -47,7 +47,7 @@ async def asyncpg_adk_store(postgres_service: "PostgresService") -> "AsyncGenera await conn.execute("DROP TABLE IF EXISTS adk_session CASCADE") await conn.execute("DROP TABLE IF EXISTS adk_user_state CASCADE") await conn.execute("DROP TABLE IF EXISTS adk_app_state CASCADE") - await conn.execute("DROP TABLE IF EXISTS adk_internal_metadata CASCADE") + await conn.execute("DROP TABLE IF EXISTS adk_metadata CASCADE") finally: if config.connection_instance: await config.close_pool() diff --git a/tests/integration/adapters/cockroach_asyncpg/extensions/adk/test_scoped_state_contract.py b/tests/integration/adapters/cockroach_asyncpg/extensions/adk/test_scoped_state_contract.py index e9652fc14..4a1d6a1f9 100644 --- a/tests/integration/adapters/cockroach_asyncpg/extensions/adk/test_scoped_state_contract.py +++ b/tests/integration/adapters/cockroach_asyncpg/extensions/adk/test_scoped_state_contract.py @@ -65,20 +65,12 @@ async def test_cockroach_asyncpg_session_table_lifecycle_contract( await assert_session_table_lifecycle_contract(cockroach_asyncpg_adk_store, marker="cockroach-asyncpg") -@pytest.mark.xfail( - reason="sqlspec-7rbl: cockroach_asyncpg multi-statement tx hits multiple_active_portals limitation; tracked separately", - strict=False, -) async def test_cockroach_asyncpg_session_scoped_state_contract( cockroach_asyncpg_adk_store: CockroachAsyncpgADKStore, ) -> None: await assert_session_scoped_state_contract(cockroach_asyncpg_adk_store, marker="cockroach-asyncpg") -@pytest.mark.xfail( - reason="sqlspec-7rbl: cockroach_asyncpg multi-statement tx hits multiple_active_portals limitation; tracked separately", - strict=False, -) async def test_cockroach_asyncpg_session_atomic_scoped_write_contract( cockroach_asyncpg_adk_store: CockroachAsyncpgADKStore, ) -> None: @@ -97,20 +89,12 @@ async def test_cockroach_asyncpg_session_empty_state_roundtrip( await assert_session_empty_state_roundtrip(cockroach_asyncpg_adk_store, marker="cockroach-asyncpg") -@pytest.mark.xfail( - reason="sqlspec-7rbl: cockroach_asyncpg multi-statement tx hits multiple_active_portals limitation; tracked separately", - strict=False, -) async def test_cockroach_asyncpg_session_sibling_app_isolation( cockroach_asyncpg_adk_store: CockroachAsyncpgADKStore, ) -> None: await assert_session_sibling_app_isolation(cockroach_asyncpg_adk_store, marker="cockroach-asyncpg") -@pytest.mark.xfail( - reason="sqlspec-7rbl: cockroach_asyncpg multi-statement tx hits multiple_active_portals limitation; tracked separately", - strict=False, -) async def test_cockroach_asyncpg_session_sibling_user_isolation( cockroach_asyncpg_adk_store: CockroachAsyncpgADKStore, ) -> None: diff --git a/tests/integration/adapters/cockroach_psycopg/extensions/adk/test_scoped_state_contract.py b/tests/integration/adapters/cockroach_psycopg/extensions/adk/test_scoped_state_contract.py index 118fbe202..f9707265e 100644 --- a/tests/integration/adapters/cockroach_psycopg/extensions/adk/test_scoped_state_contract.py +++ b/tests/integration/adapters/cockroach_psycopg/extensions/adk/test_scoped_state_contract.py @@ -19,15 +19,7 @@ assert_session_temp_state_not_persisted, ) -pytestmark = [ - pytest.mark.xdist_group("cockroachdb"), - pytest.mark.cockroachdb, - pytest.mark.integration, - pytest.mark.xfail( - reason="sqlspec-xqnf: CockroachPsycopgAsyncADKStore inherits psycopg dict-row bug; tracked separately", - strict=False, - ), -] +pytestmark = [pytest.mark.xdist_group("cockroachdb"), pytest.mark.cockroachdb, pytest.mark.integration] @pytest.fixture(scope="session") diff --git a/tests/integration/adapters/psqlpy/extensions/adk/test_scoped_state_contract.py b/tests/integration/adapters/psqlpy/extensions/adk/test_scoped_state_contract.py index 1d99674e5..6d979866e 100644 --- a/tests/integration/adapters/psqlpy/extensions/adk/test_scoped_state_contract.py +++ b/tests/integration/adapters/psqlpy/extensions/adk/test_scoped_state_contract.py @@ -86,8 +86,5 @@ async def test_psqlpy_session_sibling_user_isolation(psqlpy_adk_store: PsqlpyADK await assert_session_sibling_user_isolation(psqlpy_adk_store, marker="psqlpy") -@pytest.mark.xfail( - reason="sqlspec-8cyp: PsqlpyADKStore.get_session does not catch UndefinedTable; tracked separately", strict=False -) async def test_psqlpy_session_table_lifecycle_contract(psqlpy_adk_store: PsqlpyADKStore) -> None: await assert_session_table_lifecycle_contract(psqlpy_adk_store, marker="psqlpy") diff --git a/tests/integration/adapters/psycopg/extensions/adk/test_scoped_state_contract.py b/tests/integration/adapters/psycopg/extensions/adk/test_scoped_state_contract.py index 97788a66b..a014ec640 100644 --- a/tests/integration/adapters/psycopg/extensions/adk/test_scoped_state_contract.py +++ b/tests/integration/adapters/psycopg/extensions/adk/test_scoped_state_contract.py @@ -19,15 +19,7 @@ assert_session_temp_state_not_persisted, ) -pytestmark = [ - pytest.mark.xdist_group("postgres"), - pytest.mark.psycopg, - pytest.mark.integration, - pytest.mark.xfail( - reason="sqlspec-xqnf: PsycopgAsyncADKStore read paths return tuples instead of dicts; tracked separately", - strict=False, - ), -] +pytestmark = [pytest.mark.xdist_group("postgres"), pytest.mark.psycopg, pytest.mark.integration] @pytest.fixture(scope="session") diff --git a/tests/integration/adapters/sqlite/extensions/adk/test_memory_store.py b/tests/integration/adapters/sqlite/extensions/adk/test_memory_store.py index 19dc7fc96..e1edabb41 100644 --- a/tests/integration/adapters/sqlite/extensions/adk/test_memory_store.py +++ b/tests/integration/adapters/sqlite/extensions/adk/test_memory_store.py @@ -97,7 +97,7 @@ async def test_sqlite_memory_store_disabled_lifecycle() -> None: with config.provide_connection() as conn: cursor = conn.execute( - "SELECT name FROM sqlite_master WHERE type = 'table' AND name = ?", ("adk_memory_entries",) + "SELECT name FROM sqlite_master WHERE type = 'table' AND name = ?", ("adk_memory",) ) row = cursor.fetchone() diff --git a/tests/unit/adapters/test_bigquery_adk.py b/tests/unit/adapters/test_bigquery_adk.py index bbc78c5bd..a82e5a9db 100644 --- a/tests/unit/adapters/test_bigquery_adk.py +++ b/tests/unit/adapters/test_bigquery_adk.py @@ -30,7 +30,7 @@ def test_bigquery_adk_store_instantiates_with_defaults() -> None: assert store.events_table == "adk_event" assert store.app_state_table == "adk_app_state" assert store.user_state_table == "adk_user_state" - assert store.metadata_table == "adk_internal_metadata" + assert store.metadata_table == "adk_metadata" assert store._dataset_qualifier == "test_dataset." assert store._lookup_window_days == 30 assert store._require_partition_filter is True @@ -38,7 +38,7 @@ def test_bigquery_adk_store_instantiates_with_defaults() -> None: def test_bigquery_adk_store_honours_session_lookup_window() -> None: - """ADKBigQueryConfig.session_lookup_window_days is propagated.""" + """``bigquery.session_lookup_window_days`` from ``extension_config['adk']`` is propagated.""" store = _make_store({"bigquery": {"session_lookup_window_days": 7}}) assert store._lookup_window_days == 7 diff --git a/tests/unit/adapters/test_psycopg/test_adk_store.py b/tests/unit/adapters/test_psycopg/test_adk_store.py index 6f5ca672f..72801477a 100644 --- a/tests/unit/adapters/test_psycopg/test_adk_store.py +++ b/tests/unit/adapters/test_psycopg/test_adk_store.py @@ -38,7 +38,7 @@ def __enter__(self) -> Self: def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None: return None - def cursor(self) -> _DummyCursor: + def cursor(self, row_factory: Any = None) -> _DummyCursor: return self._cursor def commit(self) -> None: diff --git a/tests/unit/adapters/test_spanner/test_adk_store.py b/tests/unit/adapters/test_spanner/test_adk_store.py index 8bbb116ee..0705d13ab 100644 --- a/tests/unit/adapters/test_spanner/test_adk_store.py +++ b/tests/unit/adapters/test_spanner/test_adk_store.py @@ -123,7 +123,7 @@ async def test_spanner_memory_insert_entries_writes_clean_break_record() -> None assert inserted == 1 statements = run_write.call_args.args[0] sql, params, _types = statements[0] - assert "INSERT INTO adk_memory_entries" in sql + assert "INSERT INTO adk_memory" in sql assert params["content_json"] == '{"text":"hello"}' assert params["metadata_json"] == '{"source":"unit"}' assert params["inserted_at"] is timestamp diff --git a/tests/unit/extensions/test_adk/test_capabilities.py b/tests/unit/extensions/test_adk/test_capabilities.py deleted file mode 100644 index e8b9fbf85..000000000 --- a/tests/unit/extensions/test_adk/test_capabilities.py +++ /dev/null @@ -1,72 +0,0 @@ -"""Tests for ADK capability detection and override resolution.""" - -from typing import Any - -import pytest - -from sqlspec.config import ADKConfig -from sqlspec.exceptions import ImproperConfigurationError -from sqlspec.extensions.adk._capabilities import resolve_adk_capability_plan -from sqlspec.extensions.adk._config_utils import _get_adk_capability_overrides - - -class _Config: - extension_config: dict[str, dict[str, Any]] - - def __init__(self, adk_config: dict[str, Any]) -> None: - self.extension_config = {"adk": adk_config} - - -def test_adk_config_declares_capabilities_section() -> None: - assert "capabilities" in ADKConfig.__annotations__ - - -def test_capability_plan_uses_detected_features_by_default() -> None: - plan = resolve_adk_capability_plan(detected_features={"supports_generated_columns": True, "supports_vector": False}) - - assert plan.decisions["supports_generated_columns"].enabled is True - assert plan.decisions["supports_generated_columns"].source == "detected" - assert plan.decisions["supports_vector"].enabled is False - assert plan.decisions["supports_vector"].source == "detected" - - -def test_disable_override_wins_over_detected_feature() -> None: - plan = resolve_adk_capability_plan( - detected_features={"supports_generated_columns": True}, overrides={"supports_generated_columns": "disable"} - ) - - decision = plan.decisions["supports_generated_columns"] - assert decision.enabled is False - assert decision.override == "disable" - assert decision.source == "override" - - -def test_enable_override_rejects_known_unsupported_feature() -> None: - with pytest.raises(ImproperConfigurationError, match="supports_vector"): - resolve_adk_capability_plan( - detected_features={"supports_vector": False}, overrides={"supports_vector": "enable"} - ) - - -def test_enable_override_can_force_unknown_detection_result() -> None: - plan = resolve_adk_capability_plan(detected_features={}, overrides={"supports_json_table": "enable"}) - - decision = plan.decisions["supports_json_table"] - assert decision.enabled is True - assert decision.detected is None - assert decision.source == "override" - - -def test_config_capability_overrides_are_normalized() -> None: - overrides = _get_adk_capability_overrides( - _Config({"capabilities": {"overrides": {"supports_generated_columns": "disable"}}}) - ) - - assert overrides == {"supports_generated_columns": "disable"} - - -def test_invalid_capability_override_raises_configuration_error() -> None: - with pytest.raises(ImproperConfigurationError): - _get_adk_capability_overrides( - _Config({"capabilities": {"overrides": {"supports_generated_columns": "sometimes"}}}) - ) diff --git a/tests/unit/extensions/test_adk/test_config_resolution.py b/tests/unit/extensions/test_adk/test_config_resolution.py index 7e84d2e15..87790e7ce 100644 --- a/tests/unit/extensions/test_adk/test_config_resolution.py +++ b/tests/unit/extensions/test_adk/test_config_resolution.py @@ -1,4 +1,4 @@ -"""Tests for ADK clean-break configuration resolution.""" +"""Tests for ADK flat-config resolution.""" from typing import Any @@ -18,23 +18,23 @@ def __init__(self, adk_config: dict[str, Any]) -> None: self.extension_config = {"adk": adk_config} -def test_adk_config_declares_nested_capability_sections() -> None: - expected = {"schema", "memory", "search", "artifact", "optimizations", "oracle", "spanner", "adbc", "bigquery"} +def test_adk_config_uses_flat_keys() -> None: + """ADKConfig is a flat TypedDict; no per-adapter or nested negotiation blocks.""" + annotations = set(ADKConfig.__annotations__) + expected_flat = {"session_table", "events_table", "memory_table", "artifact_table", "in_memory", "owner_id_column"} + forbidden_nested = {"schema", "lifecycle", "capabilities", "optimizations", "oracle", "spanner", "adbc", "bigquery"} + assert expected_flat <= annotations + assert annotations.isdisjoint(forbidden_nested) - assert expected <= set(ADKConfig.__annotations__) - -def test_nested_schema_config_resolves_all_adk_table_names() -> None: +def test_flat_schema_config_resolves_all_adk_table_names() -> None: config = _Config({ - "session_table": "flat_sessions", - "schema": { - "session_table": "agent_sessions", - "events_table": "agent_events", - "app_state_table": "agent_app_states", - "user_state_table": "agent_user_states", - "metadata_table": "agent_metadata", - "owner_id_column": "tenant_id UUID", - }, + "session_table": "agent_sessions", + "events_table": "agent_events", + "app_state_table": "agent_app_states", + "user_state_table": "agent_user_states", + "metadata_table": "agent_metadata", + "owner_id_column": "tenant_id UUID", }) resolved = _get_adk_session_store_config(config) @@ -49,11 +49,12 @@ def test_nested_schema_config_resolves_all_adk_table_names() -> None: } -def test_nested_memory_and_search_config_resolve_memory_store_settings() -> None: +def test_flat_memory_config_resolves_memory_store_settings() -> None: config = _Config({ - "enable_memory": True, - "memory": {"enabled": False, "table": "agent_memories", "max_results": 50}, - "search": {"use_fts": True, "language": "simple"}, + "enable_memory": False, + "memory_table": "agent_memories", + "memory_use_fts": True, + "memory_max_results": 50, }) resolved = _get_adk_memory_store_config(config) @@ -61,19 +62,23 @@ def test_nested_memory_and_search_config_resolve_memory_store_settings() -> None assert resolved == {"enable_memory": False, "memory_table": "agent_memories", "use_fts": True, "max_results": 50} -def test_nested_artifact_config_resolves_table_and_storage_uri() -> None: - config = _Config({ - "artifact_table": "flat_artifacts", - "artifact_storage_uri": "file:///flat", - "artifact": {"table": "agent_artifacts", "storage_uri": "s3://bucket/adk"}, - }) +def test_flat_artifact_config_resolves_table_and_storage_uri() -> None: + config = _Config({"artifact_table": "agent_artifacts", "artifact_storage_uri": "s3://bucket/adk"}) resolved = _get_adk_artifact_store_config(config) assert resolved == {"artifact_table": "agent_artifacts", "storage_uri": "s3://bucket/adk"} -def test_schema_include_memory_migration_overrides_runtime_memory_enablement() -> None: - config = _Config({"memory": {"enabled": True}, "schema": {"include_memory_migration": False}}) +def test_include_memory_migration_overrides_enable_memory() -> None: + config = _Config({"enable_memory": True, "include_memory_migration": False}) assert not _is_adk_memory_migration_enabled(config) + + +def test_include_memory_migration_defaults_to_enable_memory() -> None: + enabled = _Config({"enable_memory": True}) + disabled = _Config({"enable_memory": False}) + + assert _is_adk_memory_migration_enabled(enabled) is True + assert _is_adk_memory_migration_enabled(disabled) is False diff --git a/tests/unit/extensions/test_adk/test_converters.py b/tests/unit/extensions/test_adk/test_converters.py index f7d5018a6..06d670820 100644 --- a/tests/unit/extensions/test_adk/test_converters.py +++ b/tests/unit/extensions/test_adk/test_converters.py @@ -1,12 +1,4 @@ -"""Unit tests for ADK session/event converters and scoped state helpers. - -Tests the NEW contract specified in Chapter 1 of the ADK Clean-Break Overhaul: -- EventRecord has exactly 5 keys (session_id, invocation_id, author, timestamp, event_data) -- event_to_record takes only (event, session_id), not (event, session_id, app_name, user_id) -- record_to_event uses Event.model_validate for full round-trip fidelity -- filter_temp_state, split_scoped_state, merge_scoped_state for scoped state handling -- session_to_record strips temp: keys from state -""" +"""Unit tests for ADK session/event converters and scoped state helpers.""" import importlib.util from datetime import datetime, timezone diff --git a/tests/unit/extensions/test_adk/test_embedding_presets.py b/tests/unit/extensions/test_adk/test_embedding_presets.py index 71b6455cb..07060f977 100644 --- a/tests/unit/extensions/test_adk/test_embedding_presets.py +++ b/tests/unit/extensions/test_adk/test_embedding_presets.py @@ -55,9 +55,14 @@ def test_resolve_with_explicit_precision_and_normalize() -> None: assert resolved.normalize is False -def test_resolve_empty_config_raises_with_preset_table() -> None: - with pytest.raises(ImproperConfigurationError, match="gemini-embedding-002"): - resolve_embedding_config(None) +def test_resolve_empty_config_falls_back_to_default_gemini_preset() -> None: + resolved = resolve_embedding_config(None) + assert resolved.preset is not None + assert resolved.preset.name == "gemini-embedding-002" + assert resolved.source == "default" + assert resolved.dim == 1536 + assert resolved.precision == "float32" + assert resolved.normalize is True def test_resolve_unknown_preset_lists_available_presets() -> None: diff --git a/tests/unit/extensions/test_adk/test_lifecycle_config.py b/tests/unit/extensions/test_adk/test_lifecycle_config.py deleted file mode 100644 index 98585b3fc..000000000 --- a/tests/unit/extensions/test_adk/test_lifecycle_config.py +++ /dev/null @@ -1,86 +0,0 @@ -"""Tests for ADK lifecycle control resolution.""" - -from typing import Any - -import pytest - -from sqlspec.config import ADKConfig -from sqlspec.exceptions import ImproperConfigurationError -from sqlspec.extensions.adk._config_utils import _get_adk_lifecycle_plan - - -class _Config: - extension_config: dict[str, dict[str, Any]] - - def __init__(self, adk_config: dict[str, Any]) -> None: - self.extension_config = {"adk": adk_config} - - -def test_adk_config_declares_lifecycle_section() -> None: - assert "lifecycle" in ADKConfig.__annotations__ - - -def test_default_lifecycle_plan_sets_indexing_controls_to_auto() -> None: - plan = _get_adk_lifecycle_plan(_Config({})) - - assert plan.partitioning is None - assert plan.retention is None - assert plan.compression is None - assert plan.indexing == { - "generated_columns": "auto", - "covering_indexes": "auto", - "search_indexes": "auto", - "json_indexes": "auto", - "vector_indexes": "auto", - } - assert plan.table_options == {} - - -def test_nested_lifecycle_sections_override_flat_legacy_keys() -> None: - plan = _get_adk_lifecycle_plan( - _Config({ - "partitioning": {"strategy": "hash", "partition_count": 4}, - "retention": {"event_ttl_seconds": 60}, - "compression": {"enabled": False}, - "session_table_options": "flat-session-options", - "lifecycle": { - "partitioning": {"strategy": "range", "interval": "month"}, - "retention": {"event_ttl_seconds": 120}, - "indexing": {"generated_columns": "enable", "covering_indexes": "disable"}, - "compression": {"enabled": True, "algorithm": "zstd"}, - "table_options": {"sessions": "nested-session-options", "events": "nested-event-options"}, - }, - }) - ) - - assert plan.partitioning == {"strategy": "range", "interval": "month"} - assert plan.retention == {"event_ttl_seconds": 120} - assert plan.compression == {"enabled": True, "algorithm": "zstd"} - assert plan.indexing["generated_columns"] == "enable" - assert plan.indexing["covering_indexes"] == "disable" - assert plan.table_options == {"sessions": "nested-session-options", "events": "nested-event-options"} - - -def test_flat_table_options_are_normalized_when_lifecycle_options_are_absent() -> None: - plan = _get_adk_lifecycle_plan( - _Config({ - "session_table_options": "session-options", - "events_table_options": "event-options", - "memory_table_options": "memory-options", - "expires_index_options": "expires-options", - }) - ) - - assert plan.table_options == { - "sessions": "session-options", - "events": "event-options", - "memory": "memory-options", - "expires_index": "expires-options", - } - - -def test_invalid_lifecycle_indexing_mode_raises_configuration_error() -> None: - config = _Config({"lifecycle": {"indexing": {"generated_columns": "sometimes"}}}) - - with pytest.raises(ImproperConfigurationError): - _get_adk_lifecycle_plan(config) diff --git a/tests/unit/extensions/test_adk/test_migrations.py b/tests/unit/extensions/test_adk/test_migrations.py new file mode 100644 index 000000000..36977d31b --- /dev/null +++ b/tests/unit/extensions/test_adk/test_migrations.py @@ -0,0 +1,151 @@ +# pyright: reportPrivateUsage = false +"""Unit tests for the ADK clean-break cutover migrations. + +Covers ``sqlspec/extensions/adk/migrations/0001_create_adk_tables.py`` (no-op +after Revision 8) and ``sqlspec/extensions/adk/migrations/0002_reset_adk_tables.py`` +(drop-and-recreate cutover). The reference adapter is asyncpg; the migration's +statement-set contract is the same regardless of which adapter resolves the +store class. +""" + +import importlib + +import pytest + +from sqlspec.adapters.asyncpg import AsyncpgConfig +from sqlspec.exceptions import SQLSpecError +from sqlspec.migrations.context import MigrationContext + +migration_0001 = importlib.import_module("sqlspec.extensions.adk.migrations.0001_create_adk_tables") +migration_0002 = importlib.import_module("sqlspec.extensions.adk.migrations.0002_reset_adk_tables") + + +def _build_config(adk: "dict[str, object] | None" = None) -> AsyncpgConfig: + return AsyncpgConfig(connection_config={"dsn": "postgresql://localhost/test"}, extension_config={"adk": adk or {}}) + + +def _build_context(adk: "dict[str, object] | None" = None) -> MigrationContext: + return MigrationContext(config=_build_config(adk)) + + +def _index_of(statements: "list[str]", needle: str) -> int: + for idx, statement in enumerate(statements): + if needle in statement: + return idx + msg = f"Expected to find {needle!r} in statements: {statements}" + raise AssertionError(msg) + + +async def test_0001_up_is_noop_with_context() -> None: + assert await migration_0001.up(_build_context()) == [] + + +async def test_0001_down_is_noop_with_context() -> None: + assert await migration_0001.down(_build_context()) == [] + + +async def test_0001_up_is_noop_without_context() -> None: + assert await migration_0001.up(None) == [] + + +async def test_0001_down_is_noop_without_context() -> None: + assert await migration_0001.down(None) == [] + + +async def test_0002_up_with_memory_enabled_emits_full_statement_set() -> None: + statements = await migration_0002.up(_build_context()) + + memory_drop_idx = _index_of(statements, "DROP TABLE IF EXISTS adk_memory") + metadata_drop_idx = _index_of(statements, "DROP TABLE IF EXISTS adk_metadata") + events_drop_idx = _index_of(statements, "DROP TABLE IF EXISTS adk_event") + session_drop_idx = _index_of(statements, "DROP TABLE IF EXISTS adk_session") + session_create_idx = _index_of(statements, "CREATE TABLE IF NOT EXISTS adk_session") + events_create_idx = _index_of(statements, "CREATE TABLE IF NOT EXISTS adk_event") + app_state_create_idx = _index_of(statements, "CREATE TABLE IF NOT EXISTS adk_app_state") + user_state_create_idx = _index_of(statements, "CREATE TABLE IF NOT EXISTS adk_user_state") + metadata_create_idx = _index_of(statements, "CREATE TABLE IF NOT EXISTS adk_metadata") + seed_idx = _index_of(statements, "INSERT INTO adk_metadata") + memory_create_idx = _index_of(statements, "CREATE TABLE IF NOT EXISTS adk_memory") + + assert memory_drop_idx < metadata_drop_idx, "memory drop must precede session-store drops" + assert metadata_drop_idx < events_drop_idx < session_drop_idx, "drops must be FK-safe" + assert session_drop_idx < session_create_idx, "creates must follow drops" + assert ( + session_create_idx + < events_create_idx + < app_state_create_idx + < user_state_create_idx + < metadata_create_idx + < seed_idx + ), "create order must satisfy FK and seed dependencies" + assert seed_idx < memory_create_idx, "memory create runs after the session-store cutover" + assert "schema_version" in statements[seed_idx] + + +async def test_0002_up_with_memory_disabled_drops_memory_table_but_skips_create() -> None: + statements = await migration_0002.up(_build_context({"enable_memory": False})) + + assert any("DROP TABLE IF EXISTS adk_memory" in stmt for stmt in statements), ( + "memory drop must be unconditional so enable_memory=True->False transitions clean up" + ) + assert all("CREATE TABLE IF NOT EXISTS adk_memory" not in stmt for stmt in statements), ( + "memory create must be skipped when memory is disabled" + ) + assert any("CREATE TABLE IF NOT EXISTS adk_session" in stmt for stmt in statements) + assert any("INSERT INTO adk_metadata" in stmt for stmt in statements) + + +async def test_0002_up_with_no_memory_store_class_skips_memory_branch_entirely(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(migration_0002, "_get_adk_memory_migration_store_class", lambda _config: None) + + statements = await migration_0002.up(_build_context()) + + assert all("memory" not in stmt.lower() for stmt in statements), ( + "no memory drop or create when the adapter ships no memory store" + ) + assert any("DROP TABLE IF EXISTS adk_session" in stmt for stmt in statements) + assert any("CREATE TABLE IF NOT EXISTS adk_session" in stmt for stmt in statements) + assert any("INSERT INTO adk_metadata" in stmt for stmt in statements) + + +async def test_0002_down_with_memory_enabled_drops_memory_then_new_tables() -> None: + statements = await migration_0002.down(_build_context()) + + memory_drop_idx = _index_of(statements, "DROP TABLE IF EXISTS adk_memory") + session_drop_idx = _index_of(statements, "DROP TABLE IF EXISTS adk_session") + metadata_drop_idx = _index_of(statements, "DROP TABLE IF EXISTS adk_metadata") + + assert memory_drop_idx < metadata_drop_idx < session_drop_idx, ( + "down() drops memory first, then new tables FK-safe (children before parents)" + ) + assert all("CREATE TABLE" not in stmt for stmt in statements) + + +async def test_0002_down_with_memory_disabled_drops_only_new_tables() -> None: + statements = await migration_0002.down(_build_context({"enable_memory": False})) + + assert all("adk_memory" not in stmt for stmt in statements), ( + "down() does not touch memory when memory is disabled for this config" + ) + assert any("DROP TABLE IF EXISTS adk_session" in stmt for stmt in statements) + assert any("DROP TABLE IF EXISTS adk_metadata" in stmt for stmt in statements) + + +async def test_0002_up_raises_when_context_missing() -> None: + with pytest.raises(SQLSpecError, match="Migration context must have a config"): + await migration_0002.up(None) + + +async def test_0002_down_raises_when_context_missing() -> None: + with pytest.raises(SQLSpecError, match="Migration context must have a config"): + await migration_0002.down(None) + + +async def test_0002_up_raises_when_context_config_missing() -> None: + with pytest.raises(SQLSpecError, match="Migration context must have a config"): + await migration_0002.up(MigrationContext(config=None)) + + +async def test_0002_down_raises_when_context_config_missing() -> None: + with pytest.raises(SQLSpecError, match="Migration context must have a config"): + await migration_0002.down(MigrationContext(config=None)) diff --git a/tests/unit/extensions/test_adk/test_service.py b/tests/unit/extensions/test_adk/test_service.py index e65b3b0ed..d4c7e5be6 100644 --- a/tests/unit/extensions/test_adk/test_service.py +++ b/tests/unit/extensions/test_adk/test_service.py @@ -1,7 +1,7 @@ -"""Unit tests for SQLSpecSessionService — state persistence fix. +"""Unit tests for SQLSpecSessionService. -Tests the NEW contract specified in Chapter 1 of the ADK Clean-Break Overhaul: -- append_event calls append_event_and_update_state (not the old append_event) +Covers the durable-state write contract: +- append_event routes to append_event_and_update_state (or append_event on no-op) - temp: keys are stripped before persisting session state - partial events are not persisted - create_session strips temp: keys from initial state @@ -37,19 +37,20 @@ class MockStore: """ def __init__(self) -> None: - # Track calls to the new combined method self.append_event_and_update_state_calls: list[dict[str, Any]] = [] self.append_event_and_update_state_called = False + self.append_event_calls: list[Any] = [] self.get_session_calls = 0 self.get_session_call_args: list[dict[str, Any]] = [] - # Track calls to create_session self.create_session_calls: list[dict[str, Any]] = [] self.upsert_app_state_calls: list[dict[str, Any]] = [] self.upsert_user_state_calls: list[dict[str, Any]] = [] self.app_state: dict[str, Any] | None = None self.user_state: dict[str, Any] | None = None + self.config = type("_Cfg", (), {"extension_config": {"adk": {}}})() + # Provide a get_session that returns a minimal session record self._session_record = { "id": "s1", @@ -63,11 +64,11 @@ def __init__(self) -> None: async def append_event_and_update_state( self, event_record: Any, - session_id: str, - state: "dict[str, Any]", + app_name: "str | None" = None, + user_id: "str | None" = None, + session_id: "str | None" = None, + state: "dict[str, Any] | None" = None, *, - app_name: str | None = None, - user_id: str | None = None, app_state: "dict[str, Any] | None" = None, user_state: "dict[str, Any] | None" = None, ) -> "dict[str, Any]": @@ -89,26 +90,32 @@ async def append_event_and_update_state( self.user_state = user_state # Return the updated SessionRecord — caller no longer needs a follow-up get_session(). updated = dict(self._session_record) - updated["state"] = state + updated["state"] = state if state is not None else {} updated["update_time"] = datetime.now(timezone.utc) self._session_record = updated return updated async def get_session( - self, session_id: str, *, renew_for: int | timedelta | None = None + self, app_name: str, user_id: str, session_id: str, *, renew_for: int | timedelta | None = None ) -> "dict[str, Any] | None": self.get_session_calls += 1 - self.get_session_call_args.append({"session_id": session_id, "renew_for": renew_for}) + self.get_session_call_args.append({ + "app_name": app_name, + "user_id": user_id, + "session_id": session_id, + "renew_for": renew_for, + }) return self._session_record async def create_session( - self, *, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]" + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> "dict[str, Any]": self.create_session_calls.append({ "session_id": session_id, "app_name": app_name, "user_id": user_id, "state": state, + "owner_id": owner_id, }) return { "id": session_id, @@ -119,11 +126,12 @@ async def create_session( "update_time": datetime.now(timezone.utc), } - # Old method — should NOT be called by the new service async def append_event(self, event_record: Any) -> None: - raise AssertionError("append_event (old method) must not be called — use append_event_and_update_state") + self.append_event_calls.append(event_record) - async def get_events(self, *, session_id: str, after_timestamp: Any = None, limit: Any = None) -> list: + async def get_events( + self, app_name: str, user_id: str, session_id: str, after_timestamp: Any = None, limit: Any = None + ) -> list: return [] async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": @@ -140,10 +148,10 @@ async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, self.upsert_user_state_calls.append({"app_name": app_name, "user_id": user_id, "state": state}) self.user_state = state - async def list_sessions(self, *, app_name: str, user_id: "str | None" = None) -> list: + async def list_sessions(self, app_name: str, user_id: "str | None" = None) -> list: return [] - async def delete_session(self, session_id: str) -> None: + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: pass @@ -187,7 +195,12 @@ async def test_get_session_forwards_renew_for_to_store() -> None: session = await service.get_session(app_name="app", user_id="u1", session_id="s1", renew_for=renew_for) assert session is not None - assert store.get_session_call_args[0] == {"session_id": "s1", "renew_for": renew_for} + assert store.get_session_call_args[0] == { + "app_name": "app", + "user_id": "u1", + "session_id": "s1", + "renew_for": renew_for, + } @pytest.mark.anyio @@ -270,7 +283,7 @@ async def test_append_event_strips_temp_from_persisted_state() -> None: store = MockStore() service = SQLSpecSessionService(store) # type: ignore[arg-type] session = _make_session(state={"key": "v", "temp:transient": "should_not_persist"}) - event = _make_event() + event = _make_event(state_delta={"key": "v"}) await service.append_event(session, event) @@ -336,7 +349,7 @@ async def test_append_event_passes_correct_session_id_to_store() -> None: store = MockStore() service = SQLSpecSessionService(store) # type: ignore[arg-type] session = _make_session(session_id="my-unique-session-id") - event = _make_event() + event = _make_event(state_delta={"key": "v"}) await service.append_event(session, event) @@ -345,18 +358,26 @@ async def test_append_event_passes_correct_session_id_to_store() -> None: @pytest.mark.anyio -async def test_append_event_event_record_has_5_keys() -> None: - """The event_record passed to the store has exactly 5 keys (new schema).""" +async def test_append_event_event_record_keys_match_full_event_schema() -> None: + """The event_record passed to the store carries the full-event JSON schema.""" store = MockStore() service = SQLSpecSessionService(store) # type: ignore[arg-type] session = _make_session() - event = _make_event() + event = _make_event(state_delta={"key": "v"}) await service.append_event(session, event) last_call = store.append_event_and_update_state_calls[-1] event_record = last_call["event_record"] - assert set(event_record.keys()) == {"session_id", "invocation_id", "author", "timestamp", "event_data"} + assert set(event_record.keys()) == { + "id", + "app_name", + "user_id", + "session_id", + "invocation_id", + "timestamp", + "event_data", + } @pytest.mark.anyio @@ -365,7 +386,7 @@ async def test_append_event_returns_the_event() -> None: store = MockStore() service = SQLSpecSessionService(store) # type: ignore[arg-type] session = _make_session() - event = _make_event(author="model") + event = _make_event(author="model", state_delta={"k": "v"}) result = await service.append_event(session, event) @@ -373,6 +394,75 @@ async def test_append_event_returns_the_event() -> None: assert result.author == "model" +# --------------------------------------------------------------------------- +# No-op state delta handling +# --------------------------------------------------------------------------- + + +@pytest.mark.anyio +async def test_append_event_skips_update_when_state_delta_empty() -> None: + """When event.actions.state_delta is empty, the session UPDATE is skipped.""" + store = MockStore() + service = SQLSpecSessionService(store) # type: ignore[arg-type] + session = _make_session(state={"key": "v"}) + event = _make_event(state_delta={}) + + await service.append_event(session, event) + + assert len(store.append_event_calls) == 1, "no-op delta must route to append_event (event-only insert)" + assert not store.append_event_and_update_state_called, "no-op delta must NOT trigger append_event_and_update_state" + + +@pytest.mark.anyio +async def test_append_event_runs_full_path_when_state_delta_present() -> None: + """When event.actions.state_delta is non-empty, take the full atomic path.""" + store = MockStore() + service = SQLSpecSessionService(store) # type: ignore[arg-type] + session = _make_session(state={"key": "v0"}) + event = _make_event(state_delta={"key": "v1"}) + + await service.append_event(session, event) + + assert store.append_event_and_update_state_called + assert store.append_event_calls == [], "non-empty delta must NOT take the no-op event-only path" + + +@pytest.mark.anyio +async def test_append_event_noop_event_record_carries_full_schema() -> None: + """The event_record sent to append_event on the no-op path carries the full key set.""" + store = MockStore() + service = SQLSpecSessionService(store) # type: ignore[arg-type] + session = _make_session() + event = _make_event(state_delta={}) + + await service.append_event(session, event) + + assert len(store.append_event_calls) == 1 + event_record = store.append_event_calls[0] + assert set(event_record.keys()) == { + "id", + "app_name", + "user_id", + "session_id", + "invocation_id", + "timestamp", + "event_data", + } + + +@pytest.mark.anyio +async def test_append_event_noop_does_not_advance_last_update_time() -> None: + """On no-op, session.last_update_time matches the pre-append storage timestamp.""" + store = MockStore() + service = SQLSpecSessionService(store) # type: ignore[arg-type] + session = _make_session(state={"key": "v"}) + event = _make_event(state_delta={}) + + await service.append_event(session, event) + + assert session.last_update_time == store._session_record["update_time"].timestamp() + + # --------------------------------------------------------------------------- # create_session — strips temp: keys from initial state # --------------------------------------------------------------------------- @@ -466,13 +556,10 @@ def __init__(self, *, stale_marker: bool = False, stale_timestamp: bool = False) self._stale_timestamp = stale_timestamp async def get_session( - self, session_id: str, *, renew_for: int | timedelta | None = None + self, app_name: str, user_id: str, session_id: str, *, renew_for: int | timedelta | None = None ) -> "dict[str, Any] | None": record = dict(self._session_record) if self._stale_marker or self._stale_timestamp: - # Simulate a storage-side update by advancing update_time - from datetime import timedelta - record["update_time"] = record["update_time"] + timedelta(seconds=10) # type: ignore[operator] return record @@ -481,7 +568,7 @@ class MissingSessionStore(MockStore): """Mock store where the session disappears between load and append.""" async def get_session( - self, session_id: str, *, renew_for: int | timedelta | None = None + self, app_name: str, user_id: str, session_id: str, *, renew_for: int | timedelta | None = None ) -> "dict[str, Any] | None": return None diff --git a/tests/unit/extensions/test_adk/test_store_config.py b/tests/unit/extensions/test_adk/test_store_config.py index 3bf4d016f..b8a2d837e 100644 --- a/tests/unit/extensions/test_adk/test_store_config.py +++ b/tests/unit/extensions/test_adk/test_store_config.py @@ -280,11 +280,9 @@ def test_session_store_contract_declares_schema_parity_hooks() -> None: def test_session_store_resolves_schema_parity_table_names() -> None: store = _AsyncSessionStore( _Config({ - "schema": { - "app_state_table": "agent_app_states", - "user_state_table": "agent_user_states", - "metadata_table": "agent_metadata", - } + "app_state_table": "agent_app_states", + "user_state_table": "agent_user_states", + "metadata_table": "agent_metadata", }) ) @@ -300,15 +298,16 @@ def test_session_store_uses_singular_default_table_names() -> None: assert store.events_table == "adk_event" assert store.app_state_table == "adk_app_state" assert store.user_state_table == "adk_user_state" - assert store.metadata_table == "adk_internal_metadata" + assert store.metadata_table == "adk_metadata" @pytest.mark.anyio async def test_adk_migration_up_includes_schema_parity_tables(monkeypatch: pytest.MonkeyPatch) -> None: - migration = __import__("sqlspec.extensions.adk.migrations.0001_create_adk_tables", fromlist=["up"]) + migration = __import__("sqlspec.extensions.adk.migrations.0002_reset_adk_tables", fromlist=["up"]) context = type("MigrationContext", (), {"config": _Config()})() monkeypatch.setattr(migration, "_get_store_class", lambda _context: _MigrationSessionStore) + monkeypatch.setattr(migration, "_get_memory_store_class", lambda _context: None) monkeypatch.setattr(migration, "_is_memory_enabled", lambda _context: False) statements = await migration.up(context) @@ -326,7 +325,7 @@ async def test_adk_migration_up_includes_schema_parity_tables(monkeypatch: pytes @pytest.mark.parametrize("field", ["app_state_table", "user_state_table", "metadata_table"]) def test_session_store_validates_schema_parity_table_names(field: str) -> None: with pytest.raises(ValueError, match="Invalid table name"): - _AsyncSessionStore(_Config({"schema": {field: "invalid-name"}})) + _AsyncSessionStore(_Config({field: "invalid-name"})) def test_session_store_contract_get_session_accepts_renew_for_kwarg() -> None: diff --git a/tests/unit/extensions/test_adk/test_versioning.py b/tests/unit/extensions/test_adk/test_versioning.py deleted file mode 100644 index 1b43ef70a..000000000 --- a/tests/unit/extensions/test_adk/test_versioning.py +++ /dev/null @@ -1,86 +0,0 @@ -"""Tests for ADK schema and payload version planning.""" - -from typing import Any - -import pytest - -from sqlspec.exceptions import ImproperConfigurationError -from sqlspec.extensions.adk._config_utils import _get_adk_version_plan -from sqlspec.extensions.adk._versioning import ( - ADK_ARTIFACT_PAYLOAD_VERSION, - ADK_EVENT_PAYLOAD_VERSION, - ADK_MEMORY_PAYLOAD_VERSION, - ADK_PAYLOAD_VERSION_KEYS, - ADK_SCHEMA_VERSION, - ADK_SCHEMA_VERSION_KEY, - ADK_STATE_PAYLOAD_VERSION, - ADKVersionPlan, - validate_adk_version_plan, -) - - -class _Config: - extension_config: dict[str, dict[str, Any]] - - def __init__(self, adk_config: dict[str, Any]) -> None: - self.extension_config = {"adk": adk_config} - - -def test_default_version_plan_matches_clean_break_v1_contract() -> None: - plan = _get_adk_version_plan(_Config({})) - - assert plan == ADKVersionPlan( - schema_version=ADK_SCHEMA_VERSION, - event_payload_version=ADK_EVENT_PAYLOAD_VERSION, - state_payload_version=ADK_STATE_PAYLOAD_VERSION, - memory_payload_version=ADK_MEMORY_PAYLOAD_VERSION, - artifact_payload_version=ADK_ARTIFACT_PAYLOAD_VERSION, - ) - - -def test_schema_version_key_matches_official_adk_metadata_key() -> None: - assert ADK_SCHEMA_VERSION_KEY == "schema_version" - - -def test_version_plan_metadata_items_include_schema_and_payload_versions() -> None: - metadata_items = dict(_get_adk_version_plan(_Config({})).metadata_items()) - - assert metadata_items == { - ADK_SCHEMA_VERSION_KEY: "1", - ADK_PAYLOAD_VERSION_KEYS["event"]: "1", - ADK_PAYLOAD_VERSION_KEYS["state"]: "1", - ADK_PAYLOAD_VERSION_KEYS["memory"]: "1", - ADK_PAYLOAD_VERSION_KEYS["artifact"]: "1", - } - - -def test_nested_schema_payload_versions_override_defaults() -> None: - plan = _get_adk_version_plan( - _Config({ - "schema": {"schema_version": 1, "payload_versions": {"event": 1, "state": 1, "memory": 1, "artifact": 1}} - }) - ) - - assert plan.event_payload_version == 1 - assert plan.state_payload_version == 1 - assert plan.memory_payload_version == 1 - assert plan.artifact_payload_version == 1 - - -@pytest.mark.parametrize( - "adk_config", - [ - {"schema": {"schema_version": 2}}, - {"schema": {"payload_versions": {"event": 2}}}, - {"schema": {"payload_versions": {"state": 2}}}, - {"schema": {"payload_versions": {"memory": 2}}}, - {"schema": {"payload_versions": {"artifact": 2}}}, - ], -) -def test_unsupported_schema_or_payload_versions_raise_configuration_error(adk_config: dict[str, Any]) -> None: - with pytest.raises(ImproperConfigurationError): - _get_adk_version_plan(_Config(adk_config)) - - -def test_validate_adk_version_plan_accepts_supported_clean_break_plan() -> None: - validate_adk_version_plan(ADKVersionPlan()) From cd1c9b933d79162559df1acb6f0906f9255d7eb7 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Tue, 26 May 2026 00:16:05 +0000 Subject: [PATCH 27/29] refactor(adk): streamline code formatting for consistency and readability --- sqlspec/extensions/adk/_config_utils.py | 4 +--- .../adk/migrations/0002_reset_adk_tables.py | 22 ++++++++----------- .../adapters/_adk_contract_helpers.py | 2 -- .../extensions/adk/test_memory_store.py | 4 +--- 4 files changed, 11 insertions(+), 21 deletions(-) diff --git a/sqlspec/extensions/adk/_config_utils.py b/sqlspec/extensions/adk/_config_utils.py index 2d0716a35..ac9619a11 100644 --- a/sqlspec/extensions/adk/_config_utils.py +++ b/sqlspec/extensions/adk/_config_utils.py @@ -105,9 +105,7 @@ def _get_adk_artifact_store_config(config: _ADKConfigSource) -> _ADKArtifactStor """Return normalized artifact store settings.""" adk_config = _get_adk_config_from_extension(config) - result: _ADKArtifactStoreConfig = { - "artifact_table": str(adk_config.get("artifact_table") or "adk_artifact") - } + result: _ADKArtifactStoreConfig = {"artifact_table": str(adk_config.get("artifact_table") or "adk_artifact")} storage_uri = adk_config.get("artifact_storage_uri") if storage_uri is not None: result["storage_uri"] = str(storage_uri) diff --git a/sqlspec/extensions/adk/migrations/0002_reset_adk_tables.py b/sqlspec/extensions/adk/migrations/0002_reset_adk_tables.py index ff7e76c10..4b01ccf12 100644 --- a/sqlspec/extensions/adk/migrations/0002_reset_adk_tables.py +++ b/sqlspec/extensions/adk/migrations/0002_reset_adk_tables.py @@ -68,22 +68,18 @@ async def up(context: "MigrationContext | None" = None) -> "list[str]": if memory_store_class is not None: memory_store = memory_store_class(config=context.config) statements.extend(memory_store._get_drop_memory_table_sql()) # pyright: ignore[reportPrivateUsage] - log_with_context( - logger, logging.DEBUG, "adk.migration.reset.memory.drop", table_name=memory_store.memory_table - ) + log_with_context(logger, logging.DEBUG, "adk.migration.reset.memory.drop", table_name=memory_store.memory_table) statements.extend(store_instance._get_drop_tables_sql()) # pyright: ignore[reportPrivateUsage] - statements.extend( - [ - await store_instance._get_create_sessions_table_sql(), # pyright: ignore[reportPrivateUsage] - await store_instance._get_create_events_table_sql(), # pyright: ignore[reportPrivateUsage] - await store_instance._get_create_app_states_table_sql(), # pyright: ignore[reportPrivateUsage] - await store_instance._get_create_user_states_table_sql(), # pyright: ignore[reportPrivateUsage] - await store_instance._get_create_metadata_table_sql(), # pyright: ignore[reportPrivateUsage] - await store_instance._get_seed_metadata_sql(), # pyright: ignore[reportPrivateUsage] - ] - ) + statements.extend([ + await store_instance._get_create_sessions_table_sql(), # pyright: ignore[reportPrivateUsage] + await store_instance._get_create_events_table_sql(), # pyright: ignore[reportPrivateUsage] + await store_instance._get_create_app_states_table_sql(), # pyright: ignore[reportPrivateUsage] + await store_instance._get_create_user_states_table_sql(), # pyright: ignore[reportPrivateUsage] + await store_instance._get_create_metadata_table_sql(), # pyright: ignore[reportPrivateUsage] + await store_instance._get_seed_metadata_sql(), # pyright: ignore[reportPrivateUsage] + ]) if _is_memory_enabled(context) and memory_store_class is not None: memory_store = memory_store_class(config=context.config) diff --git a/tests/integration/adapters/_adk_contract_helpers.py b/tests/integration/adapters/_adk_contract_helpers.py index 0fe3549b5..e064fb260 100644 --- a/tests/integration/adapters/_adk_contract_helpers.py +++ b/tests/integration/adapters/_adk_contract_helpers.py @@ -673,5 +673,3 @@ async def assert_memory_store_contract(store: MemoryStore, *, marker: str) -> No fresh_results = await store.search_entries("fresh", app_name, user_id, limit=10) assert len(fresh_results) == 1 assert fresh_results[0]["event_id"] == fresh_record["event_id"] - - diff --git a/tests/integration/adapters/sqlite/extensions/adk/test_memory_store.py b/tests/integration/adapters/sqlite/extensions/adk/test_memory_store.py index e1edabb41..870fa81e3 100644 --- a/tests/integration/adapters/sqlite/extensions/adk/test_memory_store.py +++ b/tests/integration/adapters/sqlite/extensions/adk/test_memory_store.py @@ -96,9 +96,7 @@ async def test_sqlite_memory_store_disabled_lifecycle() -> None: await store.create_tables() with config.provide_connection() as conn: - cursor = conn.execute( - "SELECT name FROM sqlite_master WHERE type = 'table' AND name = ?", ("adk_memory",) - ) + cursor = conn.execute("SELECT name FROM sqlite_master WHERE type = 'table' AND name = ?", ("adk_memory",)) row = cursor.fetchone() assert row is None From ae43112d64ced0c7413eae177749ae05410f088c Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Tue, 26 May 2026 01:33:24 +0000 Subject: [PATCH 28/29] refactor(adk): update parameter types in append_event_and_update_state for clarity --- tests/unit/extensions/test_adk/test_service.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unit/extensions/test_adk/test_service.py b/tests/unit/extensions/test_adk/test_service.py index d4c7e5be6..0e43a12e3 100644 --- a/tests/unit/extensions/test_adk/test_service.py +++ b/tests/unit/extensions/test_adk/test_service.py @@ -647,11 +647,11 @@ class FailingStore(MockStore): async def append_event_and_update_state( self, event_record: Any, - session_id: str, - state: Any, + app_name: "str | None" = None, + user_id: "str | None" = None, + session_id: "str | None" = None, + state: "dict[str, Any] | None" = None, *, - app_name: str | None = None, - user_id: str | None = None, app_state: "dict[str, Any] | None" = None, user_state: "dict[str, Any] | None" = None, ) -> "dict[str, Any]": From 3017d21347644db468f37aed087424e3a179f0b8 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Tue, 26 May 2026 03:34:45 +0000 Subject: [PATCH 29/29] fix: address adk scoped state review --- docs/extensions/adk/optimizations.rst | 9 +- docs/extensions/adk/schema.rst | 9 +- docs/extensions/adk/scoped_state.rst | 6 +- sqlspec/adapters/bigquery/adk/store.py | 44 +++++-- .../adk/migrations/0002_reset_adk_tables.py | 59 ++++++++- sqlspec/extensions/adk/service.py | 60 +++++---- sqlspec/extensions/adk/store.py | 15 ++- .../adapters/_adk_contract_helpers.py | 44 +++++++ .../aiosqlite/extensions/adk/test_store.py | 10 ++ .../sqlite/extensions/adk/test_store.py | 10 ++ tests/unit/adapters/test_bigquery_adk.py | 51 +++++++- .../extensions/test_adk/test_migrations.py | 120 ++++++++++++++++++ .../unit/extensions/test_adk/test_service.py | 117 +++++++++++++++-- 13 files changed, 483 insertions(+), 71 deletions(-) diff --git a/docs/extensions/adk/optimizations.rst b/docs/extensions/adk/optimizations.rst index 735a86d0e..c6c4c60d8 100644 --- a/docs/extensions/adk/optimizations.rst +++ b/docs/extensions/adk/optimizations.rst @@ -24,10 +24,11 @@ V2 — Skip no-op session UPDATE Status: planned. -When ``event.actions.state_delta`` is empty (and no scoped-state delta is -provided), the store skips the session ``UPDATE`` and instead bumps -``update_time`` via a lightweight ``UPDATE ... SET update_time = CURRENT_TIMESTAMP`` -or omits the bump entirely depending on the freshness contract. +When ``event.actions.state_delta`` is empty, the service still routes the event +through ``append_event_and_update_state()`` so ``update_time`` advances for +message-only/tool events. A future store-level optimization may narrow that +write to a lightweight ``UPDATE ... SET update_time = CURRENT_TIMESTAMP`` while +preserving the same freshness contract. V3 — Generated columns from JSON --------------------------------- diff --git a/docs/extensions/adk/schema.rst b/docs/extensions/adk/schema.rst index ebf5f1f05..467052a82 100644 --- a/docs/extensions/adk/schema.rst +++ b/docs/extensions/adk/schema.rst @@ -261,10 +261,13 @@ session mutations. It atomically: 1. Inserts the event record into the events table. 2. Updates the session-scoped durable state in the sessions table. +3. Optionally replaces touched app/user scoped snapshots. Both operations succeed together or fail together within a single database -transaction. App/user scoped state writes are routed separately by the session -service through the dedicated scoped-state hooks. +transaction. The service passes ``app_state`` and ``user_state`` only when an +event delta touches those scopes. When provided, each value is the full merged +snapshot for that scope, not the raw delta, because adapter stores commonly use +replace/upsert semantics. .. code-block:: python @@ -273,6 +276,8 @@ service through the dedicated scoped-state hooks. event_record=event_record, session_id=session.id, state=session_state, # temp/app/user keys already stripped + app_state=merged_app_state, # None if the event did not touch app:* keys + user_state=merged_user_state, # None if the event did not touch user:* keys ) **Why this matters:** diff --git a/docs/extensions/adk/scoped_state.rst b/docs/extensions/adk/scoped_state.rst index 99be15c83..3d8ddd204 100644 --- a/docs/extensions/adk/scoped_state.rst +++ b/docs/extensions/adk/scoped_state.rst @@ -39,8 +39,10 @@ persistence. Durable keys are split into three buckets: - Unprefixed keys are written to the session row. ``append_event_and_update_state()`` remains the store-level atomic boundary for -the event row and the session row. The scoped app/user writes are routed by the -service through the dedicated scoped-state store hooks. +the event row and the session row. When an event delta touches ``app:`` or +``user:`` keys, the service reads the latest scoped row, overlays the delta, and +passes the full merged scoped snapshot to the store. Untouched scopes are passed +as ``None`` so replace-style stores do not drop existing shared keys. Read Behavior ============= diff --git a/sqlspec/adapters/bigquery/adk/store.py b/sqlspec/adapters/bigquery/adk/store.py index eae9e5853..8fd057d9b 100644 --- a/sqlspec/adapters/bigquery/adk/store.py +++ b/sqlspec/adapters/bigquery/adk/store.py @@ -16,8 +16,8 @@ When ``ADKConfig.bigquery.session_lookup_window_days`` is set, list reads constrain ``create_time`` so partitioned scans stay cheap. ``ADKConfig.retention.event_ttl_seconds`` -maps to ``partition_expiration_days`` on the events table when ``require_partition_filter`` -is enabled. JSON is stored using BigQuery's native ``JSON`` type. +maps to ``partition_expiration_days`` on the events table only. JSON is stored +using BigQuery's native ``JSON`` type. """ import math @@ -64,7 +64,7 @@ def __init__(self, config: BigQueryConfig) -> None: self._partition_expiration_days: int | None = ( max(1, math.ceil(int(ttl_seconds) / 86400)) if ttl_seconds else None ) - self._require_partition_filter: bool = bool(bigquery_config.get("require_partition_filter", True)) + self._require_partition_filter: bool = bool(bigquery_config.get("require_partition_filter", False)) dataset_id = config.connection_config.get("dataset_id") self._dataset_qualifier: str = f"{dataset_id}." if dataset_id else "" @@ -73,6 +73,13 @@ def _qualified(self, table: str) -> str: """Return the dataset-qualified table identifier when available.""" return f"{self._dataset_qualifier}{table}" + def _partition_filter(self, column: str, *, alias: str | None = None) -> str: + """Return a broad partition predicate for opt-in require_partition_filter mode.""" + if not self._require_partition_filter: + return "" + qualified_column = f"{alias}.{column}" if alias else column + return f" AND {qualified_column} IS NOT NULL" + # ------------------------------------------------------------------ # Session CRUD # ------------------------------------------------------------------ @@ -210,6 +217,7 @@ def _get_session( SELECT id, app_name, user_id, state, create_time, update_time FROM {self._qualified(self._session_table)} WHERE app_name = @app_name AND user_id = @user_id AND id = @id + {self._partition_filter("create_time").strip()} LIMIT 1 """ rows = self._run_query( @@ -238,6 +246,7 @@ def _update_session_touch(self, app_name: str, user_id: str, session_id: str) -> UPDATE {self._qualified(self._session_table)} SET update_time = CURRENT_TIMESTAMP() WHERE app_name = @app_name AND user_id = @user_id AND id = @id + {self._partition_filter("create_time").strip()} """ self._run_query( sql, @@ -253,6 +262,7 @@ def _update_session_state(self, app_name: str, user_id: str, session_id: str, st UPDATE {self._qualified(self._session_table)} SET state = @state, update_time = CURRENT_TIMESTAMP() WHERE app_name = @app_name AND user_id = @user_id AND id = @id + {self._partition_filter("create_time").strip()} """ self._run_query( sql, @@ -295,8 +305,14 @@ def _list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[S return records def _delete_session(self, app_name: str, user_id: str, session_id: str) -> None: - events_sql = f"DELETE FROM {self._qualified(self._events_table)} WHERE session_id = @id" - sessions_sql = f"DELETE FROM {self._qualified(self._session_table)} WHERE app_name = @app_name AND user_id = @user_id AND id = @id" + events_sql = ( + f"DELETE FROM {self._qualified(self._events_table)} " + f"WHERE session_id = @id{self._partition_filter('timestamp')}" + ) + sessions_sql = ( + f"DELETE FROM {self._qualified(self._session_table)} " + f"WHERE app_name = @app_name AND user_id = @user_id AND id = @id{self._partition_filter('create_time')}" + ) self._run_query(events_sql, [self._query_param("id", session_id)]) self._run_query( sessions_sql, @@ -359,6 +375,8 @@ def _get_events( FROM {self._qualified(self._events_table)} e JOIN {self._qualified(self._session_table)} s ON e.session_id = s.id WHERE s.app_name = @app_name AND s.user_id = @user_id AND e.session_id = @session_id + {self._partition_filter("timestamp", alias="e").strip()} + {self._partition_filter("create_time", alias="s").strip()} """ params = [ self._query_param("app_name", app_name), @@ -387,14 +405,20 @@ def _get_events( ] def _delete_expired_events(self, before: datetime) -> int: - sql = f"DELETE FROM {self._qualified(self._events_table)} WHERE timestamp < @before" + sql = ( + f"DELETE FROM {self._qualified(self._events_table)} " + f"WHERE timestamp < @before{self._partition_filter('timestamp')}" + ) # BigQuery jobs don't expose affected-rows reliably across all versions; # callers treat the count as best-effort and may consult job statistics if needed. self._run_query(sql, [self._query_param("before", before, bq_type="TIMESTAMP")]) return 0 def _delete_idle_sessions(self, updated_before: datetime) -> int: - sql = f"DELETE FROM {self._qualified(self._session_table)} WHERE update_time < @before" + sql = ( + f"DELETE FROM {self._qualified(self._session_table)} " + f"WHERE update_time < @before{self._partition_filter('create_time')}" + ) self._run_query(sql, [self._query_param("before", updated_before, bq_type="TIMESTAMP")]) return 0 @@ -470,11 +494,11 @@ def _set_metadata(self, key: str, value: str) -> None: # DDL # ------------------------------------------------------------------ - def _partition_options(self) -> str: + def _partition_options(self, *, include_expiration: bool = False) -> str: parts: list[str] = [] if self._require_partition_filter: parts.append("require_partition_filter = TRUE") - if self._partition_expiration_days is not None: + if include_expiration and self._partition_expiration_days is not None: parts.append(f"partition_expiration_days = {self._partition_expiration_days}") return f"\nOPTIONS({', '.join(parts)})" if parts else "" @@ -502,7 +526,7 @@ async def _get_create_events_table_sql(self) -> str: event_data JSON ) PARTITION BY DATE(timestamp) - CLUSTER BY session_id, id{self._partition_options()} + CLUSTER BY session_id, id{self._partition_options(include_expiration=True)} """ async def _get_create_app_states_table_sql(self) -> str: diff --git a/sqlspec/extensions/adk/migrations/0002_reset_adk_tables.py b/sqlspec/extensions/adk/migrations/0002_reset_adk_tables.py index 4b01ccf12..9b25de7fa 100644 --- a/sqlspec/extensions/adk/migrations/0002_reset_adk_tables.py +++ b/sqlspec/extensions/adk/migrations/0002_reset_adk_tables.py @@ -8,7 +8,7 @@ """ import logging -from typing import TYPE_CHECKING, NoReturn, cast +from typing import TYPE_CHECKING, Final, NoReturn, cast from sqlspec.exceptions import SQLSpecError from sqlspec.extensions.adk._config_utils import ( @@ -27,6 +27,9 @@ __all__ = ("down", "up") +MIN_DROP_TABLE_TOKENS: Final = 3 +MIN_DROP_TABLE_IF_EXISTS_TOKENS: Final = 5 + def _raise_missing_config() -> NoReturn: msg = "Migration context must have a config to determine store class" @@ -55,22 +58,61 @@ def _is_memory_enabled(context: "MigrationContext | None") -> bool: return _is_adk_memory_migration_enabled(context.config) +def _is_spanner_store(store_instance: "BaseAsyncADKStore") -> bool: + return getattr(store_instance, "connector_name", None) == "spanner" + + +def _spanner_existing_tables(context: "MigrationContext") -> "set[str]": + if context.config is None: + _raise_missing_config() + database = context.config.get_database() + return {table.table_id for table in database.list_tables()} + + +def _drop_table_name(statement: str) -> str | None: + tokens = statement.strip().split() + if len(tokens) >= MIN_DROP_TABLE_TOKENS and tokens[0].upper() == "DROP" and tokens[1].upper() == "TABLE": + if ( + len(tokens) >= MIN_DROP_TABLE_IF_EXISTS_TOKENS + and tokens[2].upper() == "IF" + and tokens[3].upper() == "EXISTS" + ): + return tokens[4] + return tokens[2] + return None + + +def _filter_existing_table_drops(statements: "list[str]", existing_tables: "set[str] | None") -> "list[str]": + if existing_tables is None: + return statements + return [statement for statement in statements if (_drop_table_name(statement) in existing_tables)] + + +def _filter_memory_drops(statements: "list[str]", memory_table: str, existing_tables: "set[str] | None") -> "list[str]": + if existing_tables is None or memory_table in existing_tables: + return statements + return [] + + async def up(context: "MigrationContext | None" = None) -> "list[str]": if context is None or context.config is None: _raise_missing_config() store_class = _get_store_class(context) store_instance = store_class(config=context.config) + existing_tables = _spanner_existing_tables(context) if _is_spanner_store(store_instance) else None statements: list[str] = [] memory_store_class = _get_memory_store_class(context) if memory_store_class is not None: memory_store = memory_store_class(config=context.config) - statements.extend(memory_store._get_drop_memory_table_sql()) # pyright: ignore[reportPrivateUsage] + memory_drops = memory_store._get_drop_memory_table_sql() # pyright: ignore[reportPrivateUsage] + statements.extend(_filter_memory_drops(memory_drops, memory_store.memory_table, existing_tables)) log_with_context(logger, logging.DEBUG, "adk.migration.reset.memory.drop", table_name=memory_store.memory_table) - statements.extend(store_instance._get_drop_tables_sql()) # pyright: ignore[reportPrivateUsage] + drop_statements = store_instance._get_drop_tables_sql() # pyright: ignore[reportPrivateUsage] + statements.extend(_filter_existing_table_drops(drop_statements, existing_tables)) statements.extend([ await store_instance._get_create_sessions_table_sql(), # pyright: ignore[reportPrivateUsage] @@ -100,15 +142,18 @@ async def down(context: "MigrationContext | None" = None) -> "list[str]": _raise_missing_config() statements: list[str] = [] + store_class = _get_store_class(context) + store_instance = store_class(config=context.config) + existing_tables = _spanner_existing_tables(context) if _is_spanner_store(store_instance) else None if _is_memory_enabled(context): memory_store_class = _get_memory_store_class(context) if memory_store_class is not None: memory_store = memory_store_class(config=context.config) - statements.extend(memory_store._get_drop_memory_table_sql()) # pyright: ignore[reportPrivateUsage] + memory_drops = memory_store._get_drop_memory_table_sql() # pyright: ignore[reportPrivateUsage] + statements.extend(_filter_memory_drops(memory_drops, memory_store.memory_table, existing_tables)) - store_class = _get_store_class(context) - store_instance = store_class(config=context.config) - statements.extend(store_instance._get_drop_tables_sql()) # pyright: ignore[reportPrivateUsage] + drop_statements = store_instance._get_drop_tables_sql() # pyright: ignore[reportPrivateUsage] + statements.extend(_filter_existing_table_drops(drop_statements, existing_tables)) return statements diff --git a/sqlspec/extensions/adk/service.py b/sqlspec/extensions/adk/service.py index 1c32d1ea8..0635158f0 100644 --- a/sqlspec/extensions/adk/service.py +++ b/sqlspec/extensions/adk/service.py @@ -88,14 +88,23 @@ async def create_session( state = {} persisted_state = filter_temp_state(state) - app_state, user_state, session_state = split_scoped_state(persisted_state) + app_state_delta, user_state_delta, session_state = split_scoped_state(persisted_state) + current_app_state = await self._store.get_app_state(app_name) + current_user_state = await self._store.get_user_state(app_name, user_id) + + app_state = dict(current_app_state or {}) + if app_state_delta: + app_state.update(app_state_delta) + user_state = dict(current_user_state or {}) + if user_state_delta: + user_state.update(user_state_delta) record = await self._store.create_session( session_id=session_id, app_name=app_name, user_id=user_id, state=session_state ) - if app_state: + if app_state_delta: await self._store.upsert_app_state(app_name, app_state) - if user_state: + if user_state_delta: await self._store.upsert_user_state(app_name, user_id, user_state) record["state"] = merge_scoped_state(record["state"], app_state, user_state) log_with_context( @@ -255,13 +264,6 @@ async def append_event(self, session: "Session", event: "Event") -> "Event": event=event, app_name=session.app_name, user_id=session.user_id, session_id=session.id ) - # Build durable state: current state minus temp keys, plus the - # event's state delta (temp keys already stripped by _trim above). - durable_state = filter_temp_state(session.state) - if event.actions and event.actions.state_delta: - durable_state.update(event.actions.state_delta) - app_state, user_state, session_state = split_scoped_state(durable_state) - # --- Stale-session detection --- current_record = await self._store.get_session(session.app_name, session.user_id, session.id) if current_record is None: @@ -284,20 +286,30 @@ async def append_event(self, session: "Session", event: "Event") -> "Event": raise ValueError(msg) state_delta = (event.actions.state_delta if event.actions else None) or {} - - if not state_delta: - await self._store.append_event(event_record) - updated_record = current_record - else: - updated_record = await self._store.append_event_and_update_state( - event_record=event_record, - app_name=session.app_name, - user_id=session.user_id, - session_id=session.id, - state=session_state, - app_state=app_state or None, - user_state=user_state or None, - ) + app_state_delta, user_state_delta, session_state_delta = split_scoped_state(filter_temp_state(state_delta)) + + session_state = dict(current_record["state"]) + session_state.update(session_state_delta) + + app_state = None + if app_state_delta: + app_state = dict(await self._store.get_app_state(session.app_name) or {}) + app_state.update(app_state_delta) + + user_state = None + if user_state_delta: + user_state = dict(await self._store.get_user_state(session.app_name, session.user_id) or {}) + user_state.update(user_state_delta) + + updated_record = await self._store.append_event_and_update_state( + event_record=event_record, + app_name=session.app_name, + user_id=session.user_id, + session_id=session.id, + state=session_state, + app_state=app_state, + user_state=user_state, + ) updated_record["state"] = merge_scoped_state(updated_record["state"], app_state, user_state) # Use the returned record directly — saves a round-trip vs a follow-up get_session(). diff --git a/sqlspec/extensions/adk/store.py b/sqlspec/extensions/adk/store.py index 0fdfbcd54..aee35e2bb 100644 --- a/sqlspec/extensions/adk/store.py +++ b/sqlspec/extensions/adk/store.py @@ -306,10 +306,11 @@ async def append_event_and_update_state( and the updated session record is returned in the same round-trip so callers don't need a follow-up read. - When ``app_state`` is provided (non-None), it is upserted into the - ``app_state_table`` for ``app_name``. When ``user_state`` is provided, - it is upserted into the ``user_state_table`` for ``(app_name, user_id)``. - Empty dicts are treated as "no scoped delta" and skipped. + When ``app_state`` is provided (non-None), it is a full merged + app-scoped snapshot to replace/upsert for ``app_name``. When + ``user_state`` is provided, it is a full merged user-scoped snapshot to + replace/upsert for ``(app_name, user_id)``. ``None`` means that scope + was untouched by the event and must not be written. Args: event_record: Event record to store. @@ -318,8 +319,10 @@ async def append_event_and_update_state( session_id: Session identifier whose state should be updated. state: Post-append durable session-scoped state snapshot (``temp:`` keys already stripped by the service layer). - app_state: App-scoped state delta (``app:*`` keys) to upsert atomically. - user_state: User-scoped state delta (``user:*`` keys) to upsert atomically. + app_state: Full app-scoped state snapshot (``app:*`` keys) to + upsert atomically, or ``None`` when untouched. + user_state: Full user-scoped state snapshot (``user:*`` keys) to + upsert atomically, or ``None`` when untouched. Returns: The updated SessionRecord reflecting the new state and update_time. diff --git a/tests/integration/adapters/_adk_contract_helpers.py b/tests/integration/adapters/_adk_contract_helpers.py index e064fb260..3b3f508c3 100644 --- a/tests/integration/adapters/_adk_contract_helpers.py +++ b/tests/integration/adapters/_adk_contract_helpers.py @@ -16,6 +16,7 @@ "assert_session_event_store_contract", "assert_session_get_session_renewal_contract", "assert_session_scoped_state_contract", + "assert_session_service_event_only_touch_contract", "assert_session_sibling_app_isolation", "assert_session_sibling_user_isolation", "assert_session_table_lifecycle_contract", @@ -327,6 +328,11 @@ async def assert_session_scoped_state_contract(store: SessionEventStore, *, mark assert fetched_b is not None assert fetched_b.state == {"app:counter": 1, "user:theme": "dark"} + session_c = await service.create_session( + app_name=app_name, user_id=user_id, session_id=_contract_key(marker, "scoped-session-c"), state={} + ) + assert session_c.state == {"app:counter": 1, "user:theme": "dark"} + fetched_other_user = await service.get_session( app_name=app_name, user_id=other_user_id, session_id=other_user_session.id ) @@ -436,6 +442,44 @@ async def assert_session_temp_state_not_persisted(store: SessionEventStore, *, m assert fetched.state == {"turn": 1} +async def assert_session_service_event_only_touch_contract(store: SessionEventStore, *, marker: str) -> None: + """Assert service-level message-only events persist and touch update_time.""" + from google.adk.events.event import Event + from google.adk.events.event_actions import EventActions + + service = SQLSpecSessionService(store) # type: ignore[arg-type] + app_name = _contract_key(marker, "event-only-app") + user_id = _contract_key(marker, "event-only-user") + session_id = _contract_key(marker, "event-only-session") + + created = await service.create_session( + app_name=app_name, user_id=user_id, session_id=session_id, state={"phase": "created"} + ) + session = await service.get_session(app_name=app_name, user_id=user_id, session_id=created.id) + assert session is not None + original_update_time = session.last_update_time + await asyncio.sleep(0.02) + + event = Event( + invocation_id=_contract_key(marker, "event-only-invocation"), + author="model", + timestamp=datetime.now(timezone.utc).timestamp(), + actions=EventActions(state_delta={}), + ) + + await service.append_event(session, event) + + raw_session = await store.get_session(app_name, user_id, session_id) + assert raw_session is not None + assert raw_session["state"] == {"phase": "created"} + assert _as_utc(raw_session["update_time"]).timestamp() > original_update_time + assert session.last_update_time > original_update_time + + events = await store.get_events(app_name, user_id, session_id) + assert len(events) == 1 + assert events[0]["invocation_id"] == event.invocation_id + + async def assert_session_empty_state_roundtrip(store: SessionEventStore, *, marker: str) -> None: """Assert empty session/app/user state survives the append_event_and_update_state round-trip.""" app_name = _contract_key(marker, "empty-app") diff --git a/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py b/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py index 2e978a808..fa43e59c6 100644 --- a/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py +++ b/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py @@ -15,6 +15,7 @@ assert_session_event_store_contract, assert_session_get_session_renewal_contract, assert_session_scoped_state_contract, + assert_session_service_event_only_touch_contract, assert_session_sibling_app_isolation, assert_session_sibling_user_isolation, assert_session_table_lifecycle_contract, @@ -129,6 +130,15 @@ async def test_aiosqlite_session_temp_state_not_persisted(tmp_path: Path) -> Non await config.close_pool() +async def test_aiosqlite_session_service_event_only_touch_contract(tmp_path: Path) -> None: + """AioSQLite service-level event-only appends advance update_time.""" + config, store = await _build_store(tmp_path) + try: + await assert_session_service_event_only_touch_contract(store, marker="aiosqlite") + finally: + await config.close_pool() + + async def test_aiosqlite_session_empty_state_roundtrip(tmp_path: Path) -> None: """AioSQLite preserves empty session/app/user state through append_event_and_update_state.""" config, store = await _build_store(tmp_path) diff --git a/tests/integration/adapters/sqlite/extensions/adk/test_store.py b/tests/integration/adapters/sqlite/extensions/adk/test_store.py index 663651c8a..ab3b330c3 100644 --- a/tests/integration/adapters/sqlite/extensions/adk/test_store.py +++ b/tests/integration/adapters/sqlite/extensions/adk/test_store.py @@ -15,6 +15,7 @@ assert_session_event_store_contract, assert_session_get_session_renewal_contract, assert_session_scoped_state_contract, + assert_session_service_event_only_touch_contract, assert_session_sibling_app_isolation, assert_session_sibling_user_isolation, assert_session_table_lifecycle_contract, @@ -109,6 +110,15 @@ async def test_sqlite_session_temp_state_not_persisted(tmp_path: Path) -> None: config.close_pool() +async def test_sqlite_session_service_event_only_touch_contract(tmp_path: Path) -> None: + """SQLite service-level event-only appends advance update_time.""" + config, store = await _build_store(tmp_path) + try: + await assert_session_service_event_only_touch_contract(store, marker="sqlite") + finally: + config.close_pool() + + async def test_sqlite_session_empty_state_roundtrip(tmp_path: Path) -> None: """SQLite preserves empty session/app/user state through append_event_and_update_state.""" config, store = await _build_store(tmp_path) diff --git a/tests/unit/adapters/test_bigquery_adk.py b/tests/unit/adapters/test_bigquery_adk.py index a82e5a9db..d8bbbd3ed 100644 --- a/tests/unit/adapters/test_bigquery_adk.py +++ b/tests/unit/adapters/test_bigquery_adk.py @@ -2,6 +2,7 @@ import asyncio import importlib.util +from datetime import datetime, timezone from typing import Any import pytest @@ -33,7 +34,7 @@ def test_bigquery_adk_store_instantiates_with_defaults() -> None: assert store.metadata_table == "adk_metadata" assert store._dataset_qualifier == "test_dataset." assert store._lookup_window_days == 30 - assert store._require_partition_filter is True + assert store._require_partition_filter is False assert store._partition_expiration_days is None @@ -49,14 +50,21 @@ def test_bigquery_adk_store_derives_partition_expiration_from_retention() -> Non assert store._partition_expiration_days == 30 -def test_bigquery_adk_session_ddl_is_partitioned_and_clustered() -> None: +def test_bigquery_adk_store_honours_explicit_partition_filter_opt_in() -> None: + """Partition filters are opt-in because BigQuery DML rejects unfiltered partitioned table touches.""" + store = _make_store({"bigquery": {"require_partition_filter": True}}) + + assert store._require_partition_filter is True + + +def test_bigquery_adk_session_ddl_is_partitioned_and_clustered_without_filter_by_default() -> None: """Sessions table DDL has DATE partitioning + clustering on app_name/user_id.""" store = _make_store() ddl = asyncio.run(store._get_create_sessions_table_sql()) assert "PARTITION BY DATE(create_time)" in ddl assert "CLUSTER BY app_name, user_id, id" in ddl assert "test_dataset.adk_session" in ddl - assert "require_partition_filter = TRUE" in ddl + assert "require_partition_filter = TRUE" not in ddl def test_bigquery_adk_events_ddl_clusters_on_session_id() -> None: @@ -68,6 +76,43 @@ def test_bigquery_adk_events_ddl_clusters_on_session_id() -> None: assert "test_dataset.adk_event" in ddl +def test_bigquery_adk_event_ttl_applies_only_to_event_partitions() -> None: + """Session partitions must not inherit event TTL expiration.""" + store = _make_store({"retention": {"event_ttl_seconds": 86400 * 30}}) + + session_ddl = asyncio.run(store._get_create_sessions_table_sql()) + event_ddl = asyncio.run(store._get_create_events_table_sql()) + + assert "partition_expiration_days" not in session_ddl + assert "partition_expiration_days = 30" in event_ddl + + +def test_bigquery_adk_explicit_partition_filter_adds_query_predicates(monkeypatch: pytest.MonkeyPatch) -> None: + """Opt-in partition-filter mode adds broad predicates to partitioned table DML.""" + store = _make_store({"bigquery": {"require_partition_filter": True}}) + statements: list[str] = [] + + def capture(_store: BigQueryADKStore, sql: str, parameters: Any = None) -> list[dict[str, Any]]: + statements.append(sql) + return [] + + monkeypatch.setattr(BigQueryADKStore, "_run_query", capture) + + store._get_session("app", "user", "session") + store._update_session_touch("app", "user", "session") + store._update_session_state("app", "user", "session", {"turn": 1}) + store._delete_session("app", "user", "session") + store._get_events("app", "user", "session") + store._delete_expired_events(datetime.now(timezone.utc)) + store._delete_idle_sessions(datetime.now(timezone.utc)) + + assert any("FROM test_dataset.adk_session" in sql and "create_time IS NOT NULL" in sql for sql in statements) + assert any("UPDATE test_dataset.adk_session" in sql and "create_time IS NOT NULL" in sql for sql in statements) + assert any("DELETE FROM test_dataset.adk_session" in sql and "create_time IS NOT NULL" in sql for sql in statements) + assert any("FROM test_dataset.adk_event" in sql and "timestamp IS NOT NULL" in sql for sql in statements) + assert any("DELETE FROM test_dataset.adk_event" in sql and "timestamp IS NOT NULL" in sql for sql in statements) + + def test_bigquery_adk_scoped_state_ddl_clustered() -> None: """Scoped-state tables cluster on their access keys.""" store = _make_store() diff --git a/tests/unit/extensions/test_adk/test_migrations.py b/tests/unit/extensions/test_adk/test_migrations.py index 36977d31b..2db87b255 100644 --- a/tests/unit/extensions/test_adk/test_migrations.py +++ b/tests/unit/extensions/test_adk/test_migrations.py @@ -36,6 +36,75 @@ def _index_of(statements: "list[str]", needle: str) -> int: raise AssertionError(msg) +class _Table: + def __init__(self, table_id: str) -> None: + self.table_id = table_id + + +class _SpannerDatabase: + def __init__(self, table_ids: "list[str]") -> None: + self._table_ids = table_ids + self.list_tables_calls = 0 + + def list_tables(self) -> "list[_Table]": + self.list_tables_calls += 1 + return [_Table(table_id) for table_id in self._table_ids] + + +class _SpannerConfig: + extension_config = {"adk": {"enable_memory": False}} + + def __init__(self, database: _SpannerDatabase) -> None: + self._database = database + + def get_database(self) -> _SpannerDatabase: + return self._database + + +class _SpannerDropStore: + connector_name = "spanner" + + def __init__(self, config: _SpannerConfig) -> None: + self._config = config + + async def _get_create_sessions_table_sql(self) -> str: + return "CREATE TABLE adk_session" + + async def _get_create_events_table_sql(self) -> str: + return "CREATE TABLE adk_event" + + async def _get_create_app_states_table_sql(self) -> str: + return "CREATE TABLE adk_app_state" + + async def _get_create_user_states_table_sql(self) -> str: + return "CREATE TABLE adk_user_state" + + async def _get_create_metadata_table_sql(self) -> str: + return "CREATE TABLE adk_metadata" + + async def _get_seed_metadata_sql(self) -> str: + return "INSERT INTO adk_metadata" + + def _get_drop_tables_sql(self) -> "list[str]": + return [ + "DROP TABLE adk_metadata", + "DROP TABLE adk_user_state", + "DROP TABLE adk_app_state", + "DROP TABLE adk_event", + "DROP TABLE adk_session", + ] + + +class _SpannerMemoryDropStore: + memory_table = "adk_memory" + + def __init__(self, config: _SpannerConfig) -> None: + self._config = config + + def _get_drop_memory_table_sql(self) -> "list[str]": + return ["DROP INDEX idx_adk_memory_session", "DROP TABLE adk_memory"] + + async def test_0001_up_is_noop_with_context() -> None: assert await migration_0001.up(_build_context()) == [] @@ -108,6 +177,57 @@ async def test_0002_up_with_no_memory_store_class_skips_memory_branch_entirely(m assert any("INSERT INTO adk_metadata" in stmt for stmt in statements) +async def test_0002_up_spanner_fresh_database_skips_missing_table_drops(monkeypatch: pytest.MonkeyPatch) -> None: + database = _SpannerDatabase([]) + context = MigrationContext(config=_SpannerConfig(database)) # type: ignore[arg-type] + monkeypatch.setattr(migration_0002, "_get_store_class", lambda _context: _SpannerDropStore) + monkeypatch.setattr(migration_0002, "_get_memory_store_class", lambda _context: None) + + statements = await migration_0002.up(context) + + assert database.list_tables_calls == 1 + assert all("DROP TABLE" not in statement for statement in statements) + assert statements[:6] == [ + "CREATE TABLE adk_session", + "CREATE TABLE adk_event", + "CREATE TABLE adk_app_state", + "CREATE TABLE adk_user_state", + "CREATE TABLE adk_metadata", + "INSERT INTO adk_metadata", + ] + + +async def test_0002_up_spanner_existing_database_keeps_fk_safe_drop_order(monkeypatch: pytest.MonkeyPatch) -> None: + database = _SpannerDatabase(["adk_session", "adk_event", "adk_app_state", "adk_user_state", "adk_metadata"]) + context = MigrationContext(config=_SpannerConfig(database)) # type: ignore[arg-type] + monkeypatch.setattr(migration_0002, "_get_store_class", lambda _context: _SpannerDropStore) + monkeypatch.setattr(migration_0002, "_get_memory_store_class", lambda _context: None) + + statements = await migration_0002.up(context) + + drops = [statement for statement in statements if statement.startswith("DROP TABLE")] + assert drops == [ + "DROP TABLE adk_metadata", + "DROP TABLE adk_user_state", + "DROP TABLE adk_app_state", + "DROP TABLE adk_event", + "DROP TABLE adk_session", + ] + + +async def test_0002_up_spanner_memory_drops_are_grouped_by_existing_memory_table( + monkeypatch: pytest.MonkeyPatch, +) -> None: + database = _SpannerDatabase(["adk_memory"]) + context = MigrationContext(config=_SpannerConfig(database)) # type: ignore[arg-type] + monkeypatch.setattr(migration_0002, "_get_store_class", lambda _context: _SpannerDropStore) + monkeypatch.setattr(migration_0002, "_get_memory_store_class", lambda _context: _SpannerMemoryDropStore) + + statements = await migration_0002.up(context) + + assert statements[:2] == ["DROP INDEX idx_adk_memory_session", "DROP TABLE adk_memory"] + + async def test_0002_down_with_memory_enabled_drops_memory_then_new_tables() -> None: statements = await migration_0002.down(_build_context()) diff --git a/tests/unit/extensions/test_adk/test_service.py b/tests/unit/extensions/test_adk/test_service.py index 0e43a12e3..eed678b43 100644 --- a/tests/unit/extensions/test_adk/test_service.py +++ b/tests/unit/extensions/test_adk/test_service.py @@ -1,7 +1,7 @@ """Unit tests for SQLSpecSessionService. Covers the durable-state write contract: -- append_event routes to append_event_and_update_state (or append_event on no-op) +- append_event routes all non-partial events to append_event_and_update_state - temp: keys are stripped before persisting session state - partial events are not persisted - create_session strips temp: keys from initial state @@ -46,6 +46,8 @@ def __init__(self) -> None: self.create_session_calls: list[dict[str, Any]] = [] self.upsert_app_state_calls: list[dict[str, Any]] = [] self.upsert_user_state_calls: list[dict[str, Any]] = [] + self.get_app_state_calls: list[str] = [] + self.get_user_state_calls: list[dict[str, str]] = [] self.app_state: dict[str, Any] | None = None self.user_state: dict[str, Any] | None = None @@ -135,9 +137,11 @@ async def get_events( return [] async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + self.get_app_state_calls.append(app_name) return self.app_state async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + self.get_user_state_calls.append({"app_name": app_name, "user_id": user_id}) return self.user_state async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: @@ -313,18 +317,54 @@ async def test_append_event_strips_temp_state_delta_from_persisted_state() -> No @pytest.mark.anyio async def test_append_event_routes_scoped_state_to_app_and_user_tables() -> None: - """app:* and user:* state is not stored in the per-session state blob.""" + """Scoped event deltas are merged with latest scoped rows before replacement.""" store = MockStore() + store._session_record["state"] = {"regular": "v0"} + store.app_state = {"app:counter": 1, "app:kept": "yes"} + store.user_state = {"user:theme": "light", "user:kept": "yes"} service = SQLSpecSessionService(store) # type: ignore[arg-type] - session = _make_session(state={"regular": "v0", "app:counter": 1, "user:theme": "light"}) + session = _make_session( + state={ + "regular": "v0", + "app:counter": "stale", + "app:kept": "stale", + "user:theme": "stale", + "user:kept": "stale", + } + ) event = _make_event(state_delta={"regular": "v1", "app:counter": 2, "user:theme": "dark"}) await service.append_event(session, event) persisted_state = store.append_event_and_update_state_calls[-1]["state"] assert persisted_state == {"regular": "v1"} - assert store.upsert_app_state_calls[-1] == {"app_name": "app", "state": {"app:counter": 2}} - assert store.upsert_user_state_calls[-1] == {"app_name": "app", "user_id": "u1", "state": {"user:theme": "dark"}} + assert store.upsert_app_state_calls[-1] == {"app_name": "app", "state": {"app:counter": 2, "app:kept": "yes"}} + assert store.upsert_user_state_calls[-1] == { + "app_name": "app", + "user_id": "u1", + "state": {"user:theme": "dark", "user:kept": "yes"}, + } + + +@pytest.mark.anyio +async def test_append_event_session_only_delta_does_not_write_stale_scoped_state() -> None: + """Merged in-memory scoped state is ignored unless the event delta touches that scope.""" + store = MockStore() + store._session_record["state"] = {"regular": "stored"} + store.app_state = {"app:counter": 1} + store.user_state = {"user:theme": "dark"} + service = SQLSpecSessionService(store) # type: ignore[arg-type] + session = _make_session(state={"regular": "loaded", "app:counter": "stale", "user:theme": "stale"}) + event = _make_event(state_delta={"regular": "updated"}) + + await service.append_event(session, event) + + last_call = store.append_event_and_update_state_calls[-1] + assert last_call["state"] == {"regular": "updated"} + assert last_call["app_state"] is None + assert last_call["user_state"] is None + assert store.upsert_app_state_calls == [] + assert store.upsert_user_state_calls == [] @pytest.mark.anyio @@ -401,16 +441,21 @@ async def test_append_event_returns_the_event() -> None: @pytest.mark.anyio async def test_append_event_skips_update_when_state_delta_empty() -> None: - """When event.actions.state_delta is empty, the session UPDATE is skipped.""" + """Event-only appends still touch the session row through the atomic path.""" store = MockStore() + store._session_record["state"] = {"key": "v"} service = SQLSpecSessionService(store) # type: ignore[arg-type] session = _make_session(state={"key": "v"}) event = _make_event(state_delta={}) await service.append_event(session, event) - assert len(store.append_event_calls) == 1, "no-op delta must route to append_event (event-only insert)" - assert not store.append_event_and_update_state_called, "no-op delta must NOT trigger append_event_and_update_state" + assert store.append_event_calls == [], "service-level events must use append_event_and_update_state" + assert store.append_event_and_update_state_called + last_call = store.append_event_and_update_state_calls[-1] + assert last_call["state"] == {"key": "v"} + assert last_call["app_state"] is None + assert last_call["user_state"] is None @pytest.mark.anyio @@ -429,7 +474,7 @@ async def test_append_event_runs_full_path_when_state_delta_present() -> None: @pytest.mark.anyio async def test_append_event_noop_event_record_carries_full_schema() -> None: - """The event_record sent to append_event on the no-op path carries the full key set.""" + """The event_record sent on the event-only path carries the full key set.""" store = MockStore() service = SQLSpecSessionService(store) # type: ignore[arg-type] session = _make_session() @@ -437,8 +482,7 @@ async def test_append_event_noop_event_record_carries_full_schema() -> None: await service.append_event(session, event) - assert len(store.append_event_calls) == 1 - event_record = store.append_event_calls[0] + event_record = store.append_event_and_update_state_calls[-1]["event_record"] assert set(event_record.keys()) == { "id", "app_name", @@ -452,15 +496,18 @@ async def test_append_event_noop_event_record_carries_full_schema() -> None: @pytest.mark.anyio async def test_append_event_noop_does_not_advance_last_update_time() -> None: - """On no-op, session.last_update_time matches the pre-append storage timestamp.""" + """Event-only appends advance session.last_update_time.""" store = MockStore() + original_update_time = datetime.now(timezone.utc) - timedelta(seconds=10) + store._session_record["update_time"] = original_update_time service = SQLSpecSessionService(store) # type: ignore[arg-type] session = _make_session(state={"key": "v"}) + session.last_update_time = original_update_time.timestamp() event = _make_event(state_delta={}) await service.append_event(session, event) - assert session.last_update_time == store._session_record["update_time"].timestamp() + assert session.last_update_time > original_update_time.timestamp() # --------------------------------------------------------------------------- @@ -497,6 +544,50 @@ async def test_create_session_routes_initial_user_scoped_state() -> None: assert store.upsert_user_state_calls[-1] == {"app_name": "app", "user_id": "u1", "state": {"user:theme": "dark"}} +@pytest.mark.anyio +async def test_create_session_empty_state_returns_existing_scoped_state() -> None: + """state={} still returns the shared app/user state visible to the new session.""" + store = MockStore() + store.app_state = {"app:counter": 3} + store.user_state = {"user:theme": "dark"} + service = SQLSpecSessionService(store) # type: ignore[arg-type] + + session = await service.create_session(app_name="app", user_id="u1", state={}) + + assert store.create_session_calls[0]["state"] == {} + assert store.upsert_app_state_calls == [] + assert store.upsert_user_state_calls == [] + assert session.state == {"app:counter": 3, "user:theme": "dark"} + + +@pytest.mark.anyio +async def test_create_session_merges_initial_scoped_subset_over_existing_rows() -> None: + """Initial scoped keys replace only their keys in the existing scoped snapshots.""" + store = MockStore() + store.app_state = {"app:counter": 1, "app:kept": "yes"} + store.user_state = {"user:theme": "light", "user:kept": "yes"} + service = SQLSpecSessionService(store) # type: ignore[arg-type] + + session = await service.create_session( + app_name="app", user_id="u1", state={"regular": "seed", "app:counter": 2, "user:theme": "dark"} + ) + + assert store.create_session_calls[0]["state"] == {"regular": "seed"} + assert store.upsert_app_state_calls[-1] == {"app_name": "app", "state": {"app:counter": 2, "app:kept": "yes"}} + assert store.upsert_user_state_calls[-1] == { + "app_name": "app", + "user_id": "u1", + "state": {"user:theme": "dark", "user:kept": "yes"}, + } + assert session.state == { + "regular": "seed", + "app:counter": 2, + "app:kept": "yes", + "user:theme": "dark", + "user:kept": "yes", + } + + @pytest.mark.anyio async def test_create_session_with_only_temp_state_persists_empty() -> None: """create_session with only temp: state persists empty state dict."""