Skip to content
22 changes: 18 additions & 4 deletions python/copilot/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import threading
from collections.abc import Callable
from pathlib import Path
from typing import Any, cast
from typing import Any, cast, overload

from .generated.rpc import ServerRpc
from .generated.session_events import PermissionRequest, session_event_from_dict
Expand Down Expand Up @@ -49,6 +49,8 @@
ToolResult,
)

HandlerUnsubcribe = Callable[[], None]

# Minimum protocol version this SDK can communicate with.
# Servers reporting a version below this are rejected.
MIN_PROTOCOL_VERSION = 2
Expand Down Expand Up @@ -921,13 +923,16 @@ async def list_models(self) -> list["ModelInfo"]:
if self._models_cache is not None:
return list(self._models_cache) # Return a copy to prevent cache mutation

models: list[ModelInfo]
if self._on_list_models:
# Use custom handler instead of CLI RPC
result = self._on_list_models()
# cast needed: inspect.isawaitable isn't a type guard, so the
# linter can't narrow list[ModelInfo] | Awaitable[list[ModelInfo]]
if inspect.isawaitable(result):
models = await result
models = cast(list[ModelInfo], await result)
else:
models = result
models = cast(list[ModelInfo], result)
else:
if not self._client:
raise RuntimeError("Client not connected")
Expand Down Expand Up @@ -1087,11 +1092,20 @@ async def set_foreground_session_id(self, session_id: str) -> None:
error = response.get("error", "Unknown error")
raise RuntimeError(f"Failed to set foreground session: {error}")

@overload
def on(self, handler: SessionLifecycleHandler, /) -> HandlerUnsubcribe: ...

@overload
def on(
self, event_type: SessionLifecycleEventType, /, handler: SessionLifecycleHandler
) -> HandlerUnsubcribe: ...

def on(
self,
event_type_or_handler: SessionLifecycleEventType | SessionLifecycleHandler,
/,
handler: SessionLifecycleHandler | None = None,
) -> Callable[[], None]:
) -> HandlerUnsubcribe:
"""
Subscribe to session lifecycle events.

Expand Down
1 change: 0 additions & 1 deletion test/scenarios/auth/byok-ollama/python/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
import os
import sys
from copilot import CopilotClient

OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434/v1")
Expand Down