diff --git a/agentrun/conversation_service/__ots_backend_async_template.py b/agentrun/conversation_service/__ots_backend_async_template.py index f4821bf..a2bca0c 100644 --- a/agentrun/conversation_service/__ots_backend_async_template.py +++ b/agentrun/conversation_service/__ots_backend_async_template.py @@ -42,6 +42,7 @@ DEFAULT_CONVERSATION_SECONDARY_INDEX, DEFAULT_CONVERSATION_TABLE, DEFAULT_EVENT_TABLE, + DEFAULT_STATE_SEARCH_INDEX, DEFAULT_STATE_TABLE, DEFAULT_USER_STATE_TABLE, StateData, @@ -97,15 +98,17 @@ def __init__( self._conversation_search_index = ( f"{table_prefix}{DEFAULT_CONVERSATION_SEARCH_INDEX}" ) + self._state_search_index = f"{table_prefix}{DEFAULT_STATE_SEARCH_INDEX}" # ----------------------------------------------------------------------- # 建表(异步)/ Table creation (async) # ----------------------------------------------------------------------- async def init_tables_async(self) -> None: - """创建五张表和 Conversation 二级索引(异步)。 + """创建五张表、二级索引和多元索引(异步)。 - 表已存在时跳过(catch OTSServiceError 并 log warning)。 + 包括 Conversation 二级索引、Conversation 多元索引和 State 多元索引。 + 表或索引已存在时跳过(catch OTSServiceError 并 log warning)。 """ await self._create_conversation_table_async() await self._create_event_table_async() @@ -125,6 +128,7 @@ async def init_tables_async(self) -> None: self._user_state_table, [("agent_id", "STRING"), ("user_id", "STRING")], ) + await self.init_search_index_async() async def init_core_tables_async(self) -> None: """创建核心表(Conversation + Event)和二级索引(异步)。""" @@ -151,8 +155,12 @@ async def init_state_tables_async(self) -> None: ) async def init_search_index_async(self) -> None: - """创建 Conversation 多元索引(异步)。按需调用。""" + """创建 Conversation 和 State 多元索引(异步)。 + + 索引已存在时跳过,可重复调用。 + """ await self._create_conversation_search_index_async() + await self._create_state_search_index_async() async def _create_conversation_table_async(self) -> None: """创建 Conversation 表 + 二级索引(异步)。""" @@ -383,6 +391,87 @@ async def _create_conversation_search_index_async(self) -> None: else: raise + async def _create_state_search_index_async(self) -> None: + """创建 State 表的多元索引(异步)。 + + 支持按 session_id 独立精确匹配查询,不受主键前缀限制。 + 索引已存在时跳过。 + """ + from tablestore import FieldType # type: ignore[import-untyped] + from tablestore import IndexSetting # type: ignore[import-untyped] + from tablestore import SortOrder # type: ignore[import-untyped] + from tablestore import FieldSchema + from tablestore import ( + FieldSort as OTSFieldSort, + ) # type: ignore[import-untyped] + from tablestore import SearchIndexMeta + from tablestore import Sort as OTSSort # type: ignore[import-untyped] + + fields = [ + FieldSchema( + "agent_id", + FieldType.KEYWORD, + index=True, + enable_sort_and_agg=True, + ), + FieldSchema( + "user_id", + FieldType.KEYWORD, + index=True, + enable_sort_and_agg=True, + ), + FieldSchema( + "session_id", + FieldType.KEYWORD, + index=True, + enable_sort_and_agg=True, + ), + FieldSchema( + "created_at", + FieldType.LONG, + index=True, + enable_sort_and_agg=True, + ), + FieldSchema( + "updated_at", + FieldType.LONG, + index=True, + enable_sort_and_agg=True, + ), + ] + + index_setting = IndexSetting(routing_fields=["agent_id"]) + index_sort = OTSSort( + sorters=[OTSFieldSort("updated_at", sort_order=SortOrder.DESC)] + ) + index_meta = SearchIndexMeta( + fields, + index_setting=index_setting, + index_sort=index_sort, + ) + + try: + await self._async_client.create_search_index( + self._state_table, + self._state_search_index, + index_meta, + ) + logger.info( + "Created search index: %s on table: %s", + self._state_search_index, + self._state_table, + ) + except OTSServiceError as e: + if "already exist" in str(e).lower() or ( + hasattr(e, "code") and e.code == "OTSObjectAlreadyExist" + ): + logger.warning( + "Search index %s already exists, skipping.", + self._state_search_index, + ) + else: + raise + # ----------------------------------------------------------------------- # Session CRUD(异步)/ Session CRUD (async) # ----------------------------------------------------------------------- diff --git a/agentrun/conversation_service/__session_store_async_template.py b/agentrun/conversation_service/__session_store_async_template.py index 7f827a0..81ae8ca 100644 --- a/agentrun/conversation_service/__session_store_async_template.py +++ b/agentrun/conversation_service/__session_store_async_template.py @@ -35,7 +35,10 @@ def __init__(self, ots_backend: OTSBackend) -> None: self._backend = ots_backend async def init_tables_async(self) -> None: - """创建所有 OTS 表和索引(异步)。代理到 OTSBackend.init_tables_async()。""" + """创建所有 OTS 表、二级索引和多元索引(异步)。 + + 包括建表和创建搜索索引,无需再单独调用 init_search_index_async()。 + """ await self._backend.init_tables_async() async def init_core_tables_async(self) -> None: @@ -47,7 +50,10 @@ async def init_state_tables_async(self) -> None: await self._backend.init_state_tables_async() async def init_search_index_async(self) -> None: - """创建 Conversation 多元索引(异步)。按需调用。""" + """创建 Conversation 和 State 多元索引(异步)。 + + 索引已存在时跳过,可重复调用。 + """ await self._backend.init_search_index_async() # ------------------------------------------------------------------- diff --git a/agentrun/conversation_service/model.py b/agentrun/conversation_service/model.py index 96aceef..4866611 100644 --- a/agentrun/conversation_service/model.py +++ b/agentrun/conversation_service/model.py @@ -21,6 +21,7 @@ DEFAULT_USER_STATE_TABLE = "user_state" DEFAULT_CONVERSATION_SECONDARY_INDEX = "conversation_secondary_index" DEFAULT_CONVERSATION_SEARCH_INDEX = "conversation_search_index" +DEFAULT_STATE_SEARCH_INDEX = "state_search_index" # --------------------------------------------------------------------------- diff --git a/agentrun/conversation_service/ots_backend.py b/agentrun/conversation_service/ots_backend.py index 0a469be..3ccb0b9 100644 --- a/agentrun/conversation_service/ots_backend.py +++ b/agentrun/conversation_service/ots_backend.py @@ -52,6 +52,7 @@ DEFAULT_CONVERSATION_SECONDARY_INDEX, DEFAULT_CONVERSATION_TABLE, DEFAULT_EVENT_TABLE, + DEFAULT_STATE_SEARCH_INDEX, DEFAULT_STATE_TABLE, DEFAULT_USER_STATE_TABLE, StateData, @@ -107,15 +108,17 @@ def __init__( self._conversation_search_index = ( f"{table_prefix}{DEFAULT_CONVERSATION_SEARCH_INDEX}" ) + self._state_search_index = f"{table_prefix}{DEFAULT_STATE_SEARCH_INDEX}" # ----------------------------------------------------------------------- # 建表(异步)/ Table creation (async) # ----------------------------------------------------------------------- async def init_tables_async(self) -> None: - """创建五张表和 Conversation 二级索引(异步)。 + """创建五张表、二级索引和多元索引(异步)。 - 表已存在时跳过(catch OTSServiceError 并 log warning)。 + 包括 Conversation 二级索引、Conversation 多元索引和 State 多元索引。 + 表或索引已存在时跳过(catch OTSServiceError 并 log warning)。 """ await self._create_conversation_table_async() await self._create_event_table_async() @@ -135,11 +138,13 @@ async def init_tables_async(self) -> None: self._user_state_table, [("agent_id", "STRING"), ("user_id", "STRING")], ) + await self.init_search_index_async() def init_tables(self) -> None: - """创建五张表和 Conversation 二级索引(同步)。 + """创建五张表、二级索引和多元索引(同步)。 - 表已存在时跳过(catch OTSServiceError 并 log warning)。 + 包括 Conversation 二级索引、Conversation 多元索引和 State 多元索引。 + 表或索引已存在时跳过(catch OTSServiceError 并 log warning)。 """ self._create_conversation_table() self._create_event_table() @@ -159,6 +164,7 @@ def init_tables(self) -> None: self._user_state_table, [("agent_id", "STRING"), ("user_id", "STRING")], ) + self.init_search_index() async def init_core_tables_async(self) -> None: """创建核心表(Conversation + Event)和二级索引(异步)。""" @@ -209,12 +215,20 @@ def init_state_tables(self) -> None: ) async def init_search_index_async(self) -> None: - """创建 Conversation 多元索引(异步)。按需调用。""" + """创建 Conversation 和 State 多元索引(异步)。 + + 索引已存在时跳过,可重复调用。 + """ await self._create_conversation_search_index_async() + await self._create_state_search_index_async() def init_search_index(self) -> None: - """创建 Conversation 多元索引(同步)。按需调用。""" + """创建 Conversation 和 State 多元索引(同步)。 + + 索引已存在时跳过,可重复调用。 + """ self._create_conversation_search_index() + self._create_state_search_index() async def _create_conversation_table_async(self) -> None: """创建 Conversation 表 + 二级索引(异步)。""" @@ -567,6 +581,87 @@ async def _create_conversation_search_index_async(self) -> None: else: raise + async def _create_state_search_index_async(self) -> None: + """创建 State 表的多元索引(异步)。 + + 支持按 session_id 独立精确匹配查询,不受主键前缀限制。 + 索引已存在时跳过。 + """ + from tablestore import FieldType # type: ignore[import-untyped] + from tablestore import IndexSetting # type: ignore[import-untyped] + from tablestore import SortOrder # type: ignore[import-untyped] + from tablestore import FieldSchema + from tablestore import ( + FieldSort as OTSFieldSort, + ) # type: ignore[import-untyped] + from tablestore import SearchIndexMeta + from tablestore import Sort as OTSSort # type: ignore[import-untyped] + + fields = [ + FieldSchema( + "agent_id", + FieldType.KEYWORD, + index=True, + enable_sort_and_agg=True, + ), + FieldSchema( + "user_id", + FieldType.KEYWORD, + index=True, + enable_sort_and_agg=True, + ), + FieldSchema( + "session_id", + FieldType.KEYWORD, + index=True, + enable_sort_and_agg=True, + ), + FieldSchema( + "created_at", + FieldType.LONG, + index=True, + enable_sort_and_agg=True, + ), + FieldSchema( + "updated_at", + FieldType.LONG, + index=True, + enable_sort_and_agg=True, + ), + ] + + index_setting = IndexSetting(routing_fields=["agent_id"]) + index_sort = OTSSort( + sorters=[OTSFieldSort("updated_at", sort_order=SortOrder.DESC)] + ) + index_meta = SearchIndexMeta( + fields, + index_setting=index_setting, + index_sort=index_sort, + ) + + try: + await self._async_client.create_search_index( + self._state_table, + self._state_search_index, + index_meta, + ) + logger.info( + "Created search index: %s on table: %s", + self._state_search_index, + self._state_table, + ) + except OTSServiceError as e: + if "already exist" in str(e).lower() or ( + hasattr(e, "code") and e.code == "OTSObjectAlreadyExist" + ): + logger.warning( + "Search index %s already exists, skipping.", + self._state_search_index, + ) + else: + raise + # ----------------------------------------------------------------------- # Session CRUD(异步)/ Session CRUD (async) # ----------------------------------------------------------------------- @@ -678,6 +773,87 @@ def _create_conversation_search_index(self) -> None: else: raise + def _create_state_search_index(self) -> None: + """创建 State 表的多元索引(同步)。 + + 支持按 session_id 独立精确匹配查询,不受主键前缀限制。 + 索引已存在时跳过。 + """ + from tablestore import FieldType # type: ignore[import-untyped] + from tablestore import IndexSetting # type: ignore[import-untyped] + from tablestore import SortOrder # type: ignore[import-untyped] + from tablestore import FieldSchema + from tablestore import ( + FieldSort as OTSFieldSort, + ) # type: ignore[import-untyped] + from tablestore import SearchIndexMeta + from tablestore import Sort as OTSSort # type: ignore[import-untyped] + + fields = [ + FieldSchema( + "agent_id", + FieldType.KEYWORD, + index=True, + enable_sort_and_agg=True, + ), + FieldSchema( + "user_id", + FieldType.KEYWORD, + index=True, + enable_sort_and_agg=True, + ), + FieldSchema( + "session_id", + FieldType.KEYWORD, + index=True, + enable_sort_and_agg=True, + ), + FieldSchema( + "created_at", + FieldType.LONG, + index=True, + enable_sort_and_agg=True, + ), + FieldSchema( + "updated_at", + FieldType.LONG, + index=True, + enable_sort_and_agg=True, + ), + ] + + index_setting = IndexSetting(routing_fields=["agent_id"]) + index_sort = OTSSort( + sorters=[OTSFieldSort("updated_at", sort_order=SortOrder.DESC)] + ) + index_meta = SearchIndexMeta( + fields, + index_setting=index_setting, + index_sort=index_sort, + ) + + try: + self._client.create_search_index( + self._state_table, + self._state_search_index, + index_meta, + ) + logger.info( + "Created search index: %s on table: %s", + self._state_search_index, + self._state_table, + ) + except OTSServiceError as e: + if "already exist" in str(e).lower() or ( + hasattr(e, "code") and e.code == "OTSObjectAlreadyExist" + ): + logger.warning( + "Search index %s already exists, skipping.", + self._state_search_index, + ) + else: + raise + # ----------------------------------------------------------------------- # Session CRUD(同步)/ Session CRUD (async) # ----------------------------------------------------------------------- diff --git a/agentrun/conversation_service/session_store.py b/agentrun/conversation_service/session_store.py index 49062a3..48378f8 100644 --- a/agentrun/conversation_service/session_store.py +++ b/agentrun/conversation_service/session_store.py @@ -45,11 +45,17 @@ def __init__(self, ots_backend: OTSBackend) -> None: self._backend = ots_backend async def init_tables_async(self) -> None: - """创建所有 OTS 表和索引(异步)。代理到 OTSBackend.init_tables_async()。""" + """创建所有 OTS 表、二级索引和多元索引(异步)。 + + 包括建表和创建搜索索引,无需再单独调用 init_search_index_async()。 + """ await self._backend.init_tables_async() def init_tables(self) -> None: - """创建所有 OTS 表和索引(同步)。代理到 OTSBackend.init_tables()。""" + """创建所有 OTS 表、二级索引和多元索引(同步)。 + + 包括建表和创建搜索索引,无需再单独调用 init_search_index()。 + """ self._backend.init_tables() async def init_core_tables_async(self) -> None: @@ -69,7 +75,10 @@ def init_state_tables(self) -> None: self._backend.init_state_tables() async def init_search_index_async(self) -> None: - """创建 Conversation 多元索引(异步)。按需调用。""" + """创建 Conversation 和 State 多元索引(异步)。 + + 索引已存在时跳过,可重复调用。 + """ await self._backend.init_search_index_async() # ------------------------------------------------------------------- @@ -77,7 +86,10 @@ async def init_search_index_async(self) -> None: # ------------------------------------------------------------------- def init_search_index(self) -> None: - """创建 Conversation 多元索引(同步)。按需调用。""" + """创建 Conversation 和 State 多元索引(同步)。 + + 索引已存在时跳过,可重复调用。 + """ self._backend.init_search_index() # ------------------------------------------------------------------- diff --git a/examples/conversation_service_adk_server.py b/examples/conversation_service_adk_server.py new file mode 100644 index 0000000..6b9a817 --- /dev/null +++ b/examples/conversation_service_adk_server.py @@ -0,0 +1,261 @@ +"""Google ADK Agent Server —— 使用 OTSSessionService 持久化会话。 + +集成步骤: + Step 1: 初始化 SessionStore(OTS 后端) + Step 2: 创建 OTSSessionService + Step 3: 创建 ADK Agent + Runner,传入 session_service + Step 4: 实现 invoke_agent,将 AgentRequest 转为 ADK 调用并流式输出 + Step 5: 通过 AgentRunServer 启动 HTTP 服务 + +使用方式: + uv run --env-file .env python examples/conversation_service_adk_server.py +""" + +from __future__ import annotations + +import os +import sys +from typing import Any +import uuid + +from dotenv import load_dotenv +from google.adk.agents import Agent # type: ignore[import-untyped] +from google.adk.models.lite_llm import LiteLlm +from google.adk.runners import Runner # type: ignore[import-untyped] +from google.adk.tools import ToolContext # type: ignore[import-untyped] +from google.genai import types # type: ignore[import-untyped] + +from agentrun import AgentRequest +from agentrun.conversation_service import SessionStore +from agentrun.conversation_service.adapters import OTSSessionService +from agentrun.server import AgentRunServer + +load_dotenv() + +# ── 配置参数 ────────────────────────────────────────────────── +APP_NAME = "adk_chat_server" +MEMORY_COLLECTION_NAME = os.getenv("MEMORY_COLLECTION_NAME", "") +DASHSCOPE_API_KEY = os.getenv("DASHSCOPE_API_KEY", "") + +if not MEMORY_COLLECTION_NAME: + print("ERROR: 请设置环境变量 MEMORY_COLLECTION_NAME") + sys.exit(1) +if not DASHSCOPE_API_KEY: + print("ERROR: 请设置环境变量 DASHSCOPE_API_KEY") + sys.exit(1) + + +# ── 工具定义 ────────────────────────────────────────────────── + + +def get_weather(city: str) -> dict[str, Any]: + """查询指定城市的天气信息。""" + data = { + "北京": {"weather": "晴", "temperature": "5~15°C"}, + "上海": {"weather": "多云", "temperature": "12~20°C"}, + } + return data.get(city, {"error": "暂无该城市数据"}) + + +def get_session_state(tool_context: ToolContext) -> dict[str, Any]: + """获取当前会话的状态信息。 + + 当用户询问会话状态、对话轮数、历史记录等信息时调用此工具。 + 返回 OTS 中持久化的完整 session state,包括: + - turn_count: 对话轮数 + - last_user_input: 上一轮用户输入 + - last_reply: 上一轮 agent 回复(由 output_key 自动写入) + - app:model_name: 使用的模型名称 + - user:language: 用户语言偏好 + """ + return tool_context.state.to_dict() + + +# ── Step 1: 初始化 SessionStore ────────────────────────────── + +store = SessionStore.from_memory_collection(MEMORY_COLLECTION_NAME) +store.init_tables() + +# ── Step 2: 创建 OTSSessionService ────────────────────────── + +session_service = OTSSessionService(session_store=store) + +# ── Step 3: 创建 ADK Agent + Runner ───────────────────────── + +custom_model = LiteLlm( + model="openai/qwen3-max", + api_key=DASHSCOPE_API_KEY, + api_base="https://dashscope.aliyuncs.com/compatible-mode/v1", +) +agent = Agent( + name="smart_assistant", + model=custom_model, + instruction=( + "你是一个友好的中文智能助手。\n" + "- 用户问天气时调用 get_weather\n" + "- 用户询问会话状态、对话轮数、历史记录等信息时调用 get_session_state" + ), + tools=[get_weather, get_session_state], + output_key="last_reply", +) + +runner = Runner( + agent=agent, + app_name=APP_NAME, + session_service=session_service, +) + + +# ── 辅助函数 ────────────────────────────────────────────────── + + +def _get_session_id(req: AgentRequest) -> str: + """从请求 header 中提取 session_id,没有则生成一个。""" + raw_headers: dict[str, str] = {} + if hasattr(req, "raw_request") and req.raw_request: + raw_headers = dict(req.raw_request.headers) + + return ( + raw_headers.get("X-AgentRun-Session-ID") + or raw_headers.get("x-agentrun-session-id") + or raw_headers.get("X-Agentrun-Session-Id") + or f"chat_{uuid.uuid4().hex[:8]}" + ) + + +def _get_user_id(req: AgentRequest) -> str: + """从请求 header 中提取 user_id,没有则使用默认值。""" + raw_headers: dict[str, str] = {} + if hasattr(req, "raw_request") and req.raw_request: + raw_headers = dict(req.raw_request.headers) + + return ( + raw_headers.get("X-AgentRun-User-ID") + or raw_headers.get("x-agentrun-user-id") + or "default_user" + ) + + +async def _get_or_create_session(user_id: str, session_id: str) -> Any: + """获取已有 session,不存在则自动创建。 + + ADK Runner 需要一个已存在的 session 才能运行, + 所以首次请求时需要创建 session。 + """ + existing = await session_service.get_session( + app_name=APP_NAME, + user_id=user_id, + session_id=session_id, + ) + if existing is not None: + return existing + + return await session_service.create_session( + app_name=APP_NAME, + user_id=user_id, + session_id=session_id, + state={ + "app:model_name": custom_model.model, + "user:language": "zh-CN", + }, + ) + + +# ── Step 4: invoke_agent —— 核心 Server 处理函数 ───────────── + + +async def invoke_agent(req: AgentRequest): + """将 AgentRequest 转换为 ADK 调用并流式输出文本。 + + 流程: + 1. 从 header 提取 session_id / user_id + 2. 获取或创建 ADK session(不存在则自动创建) + 3. 打印当前 session 状态(展示 OTS 持久化效果) + 4. 取最后一条用户消息,转为 ADK Content + 5. 调用 runner.run_async() 流式输出 + 6. output_key="last_reply" 自动将回复写入状态 + 7. 手动更新额外状态(turn_count / last_user_input) + """ + session_id = _get_session_id(req) + user_id = _get_user_id(req) + + # 获取或创建 session + session = await _get_or_create_session(user_id, session_id) + + # ── 读取并展示当前 session 状态(体现 OTS 持久化能力) ──── + # + # 首次请求时状态为初始值;后续请求可以看到上一轮的 + # turn_count、last_user_input 以及 output_key 自动写入的 last_reply。 + turn_count = session.state.get("turn_count", 0) + print( + f"[Session {session.id}] " + f"turn_count={turn_count}, " + f"last_user_input={session.state.get('last_user_input', '(无)')}, " + f"last_reply={session.state.get('last_reply', '(无)')}" + ) + + # 提取最后一条用户消息 + last_user_text = "" + for msg in reversed(req.messages): + if msg.role == "user": + last_user_text = msg.content or "" + break + + if not last_user_text: + yield "请输入您的问题。" + return + + # 转换为 ADK Content 格式 + content = types.Content( + role="user", + parts=[types.Part(text=last_user_text)], + ) + + # 调用 ADK Runner 流式输出 + try: + async for event in runner.run_async( + user_id=user_id, + session_id=session.id, + new_message=content, + ): + if ( + event.is_final_response() + and event.content + and event.content.parts + ): + for part in event.content.parts: + if part.text: + yield part.text + + # ── 更新额外的 session 状态 ────────────────────────── + # + # output_key="last_reply" 已由 ADK Runner 自动将 agent 回复 + # 写入 session.state["last_reply"],此处额外记录: + # - turn_count: 对话轮数(递增) + # - last_user_input: 本轮用户输入 + await store.update_session_state_async( + APP_NAME, + user_id, + session.id, + { + "turn_count": turn_count + 1, + "last_user_input": last_user_text, + }, + ) + print(f"[Session {session.id}] 状态已更新: turn_count={turn_count + 1}") + + except Exception as e: + print(f"ADK Runner 执行异常: {e}") + raise Exception("Internal Error") + + +# ── Step 5: 启动 Server ────────────────────────────────────── + +if __name__ == "__main__": + server = AgentRunServer( + invoke_agent=invoke_agent, memory_collection_name=MEMORY_COLLECTION_NAME + ) + print(f"App Name: {APP_NAME}") + print(f"Memory Collection: {MEMORY_COLLECTION_NAME}") + print("请求时通过 X-AgentRun-Session-ID header 指定会话 ID") + server.start(port=9000) diff --git a/tests/unittests/conversation_service/test_ots_backend.py b/tests/unittests/conversation_service/test_ots_backend.py index ac21dba..3b22526 100644 --- a/tests/unittests/conversation_service/test_ots_backend.py +++ b/tests/unittests/conversation_service/test_ots_backend.py @@ -223,7 +223,7 @@ def test_init_search_index_success(self) -> None: client = _make_mock_client() backend = _make_backend(client) backend.init_search_index() - client.create_search_index.assert_called_once() + assert client.create_search_index.call_count == 2 def test_init_search_index_already_exist(self) -> None: client = _make_mock_client() @@ -237,8 +237,16 @@ def test_init_search_index_already_exist(self) -> None: def test_init_search_index_other_error(self) -> None: client = _make_mock_client() - err = OTSServiceError(500, "InternalError", "internal error") - client.create_search_index.side_effect = err + call_count = 0 + + def _side_effect(*args: object, **kwargs: object) -> None: + nonlocal call_count + call_count += 1 + if call_count == 1: + return None + raise OTSServiceError(500, "InternalError", "internal error") + + client.create_search_index.side_effect = _side_effect backend = _make_backend(client) with pytest.raises(OTSServiceError): @@ -252,6 +260,7 @@ def test_table_prefix(self) -> None: assert backend._state_table == "myprefix_state" assert backend._app_state_table == "myprefix_app_state" assert backend._user_state_table == "myprefix_user_state" + assert backend._state_search_index == "myprefix_state_search_index" # --------------------------------------------------------------------------- @@ -1225,6 +1234,7 @@ async def test_init_tables_already_exist(self) -> None: async_client = MagicMock() err = OTSServiceError(409, "OTSObjectAlreadyExist", "already exist") async_client.create_table = AsyncMock(side_effect=err) + async_client.create_search_index = AsyncMock(side_effect=err) backend = _make_async_backend(async_client) await backend.init_tables_async() @@ -1233,6 +1243,7 @@ async def test_init_tables_other_error(self) -> None: async_client = MagicMock() err = OTSServiceError(500, "InternalError", "error") async_client.create_table = AsyncMock(side_effect=err) + async_client.create_search_index = AsyncMock(side_effect=err) backend = _make_async_backend(async_client) with pytest.raises(OTSServiceError): await backend.init_tables_async() @@ -1253,7 +1264,7 @@ async def test_init_state_tables(self) -> None: async def test_init_search_index(self) -> None: backend = _make_async_backend() await backend.init_search_index_async() - backend._async_client.create_search_index.assert_called_once() + assert backend._async_client.create_search_index.call_count == 2 @pytest.mark.asyncio async def test_init_search_index_already_exist(self) -> None: @@ -1268,8 +1279,16 @@ async def test_init_search_index_already_exist(self) -> None: async def test_init_search_index_other_error(self) -> None: async_client = MagicMock() async_client.create_table = AsyncMock() - err = OTSServiceError(500, "InternalError", "error") - async_client.create_search_index = AsyncMock(side_effect=err) + call_count = 0 + + async def _side_effect(*args: object, **kwargs: object) -> None: + nonlocal call_count + call_count += 1 + if call_count == 1: + return None + raise OTSServiceError(500, "InternalError", "error") + + async_client.create_search_index = AsyncMock(side_effect=_side_effect) backend = _make_async_backend(async_client) with pytest.raises(OTSServiceError): await backend.init_search_index_async()