diff --git a/python/copilot/client.py b/python/copilot/client.py index ff587d99..f515e510 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -21,7 +21,7 @@ import threading from collections.abc import Callable from pathlib import Path -from typing import Any, cast +from typing import Any, cast, overload from .generated.rpc import ServerRpc from .generated.session_events import PermissionRequest, session_event_from_dict @@ -49,6 +49,8 @@ ToolResult, ) +HandlerUnsubcribe = Callable[[], None] + # Minimum protocol version this SDK can communicate with. # Servers reporting a version below this are rejected. MIN_PROTOCOL_VERSION = 2 @@ -921,13 +923,16 @@ async def list_models(self) -> list["ModelInfo"]: if self._models_cache is not None: return list(self._models_cache) # Return a copy to prevent cache mutation + models: list[ModelInfo] if self._on_list_models: # Use custom handler instead of CLI RPC result = self._on_list_models() + # cast needed: inspect.isawaitable isn't a type guard, so the + # linter can't narrow list[ModelInfo] | Awaitable[list[ModelInfo]] if inspect.isawaitable(result): - models = await result + models = cast(list[ModelInfo], await result) else: - models = result + models = cast(list[ModelInfo], result) else: if not self._client: raise RuntimeError("Client not connected") @@ -1087,11 +1092,20 @@ async def set_foreground_session_id(self, session_id: str) -> None: error = response.get("error", "Unknown error") raise RuntimeError(f"Failed to set foreground session: {error}") + @overload + def on(self, handler: SessionLifecycleHandler, /) -> HandlerUnsubcribe: ... + + @overload + def on( + self, event_type: SessionLifecycleEventType, /, handler: SessionLifecycleHandler + ) -> HandlerUnsubcribe: ... + def on( self, event_type_or_handler: SessionLifecycleEventType | SessionLifecycleHandler, + /, handler: SessionLifecycleHandler | None = None, - ) -> Callable[[], None]: + ) -> HandlerUnsubcribe: """ Subscribe to session lifecycle events. diff --git a/test/scenarios/auth/byok-ollama/python/main.py b/test/scenarios/auth/byok-ollama/python/main.py index b86c76ba..e54a61be 100644 --- a/test/scenarios/auth/byok-ollama/python/main.py +++ b/test/scenarios/auth/byok-ollama/python/main.py @@ -1,6 +1,5 @@ import asyncio import os -import sys from copilot import CopilotClient OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434/v1")