diff --git a/assemblyai/streaming/v3/__init__.py b/assemblyai/streaming/v3/__init__.py index e89ad55..e351020 100644 --- a/assemblyai/streaming/v3/__init__.py +++ b/assemblyai/streaming/v3/__init__.py @@ -1,3 +1,4 @@ +from .async_client import AsyncStreamingClient from .client import StreamingClient from .models import ( BeginEvent, @@ -21,6 +22,7 @@ ) __all__ = [ + "AsyncStreamingClient", "BeginEvent", "Encoding", "EventMessage", diff --git a/assemblyai/streaming/v3/async_client.py b/assemblyai/streaming/v3/async_client.py new file mode 100644 index 0000000..3f36d50 --- /dev/null +++ b/assemblyai/streaming/v3/async_client.py @@ -0,0 +1,486 @@ +import asyncio +import collections.abc +import inspect +import json +import logging +import sys +from typing import ( + Any, + AsyncIterable, + Callable, + Dict, + Iterable, + List, + Optional, + Union, +) + +import httpx +import websockets +from pydantic import BaseModel + +# Prefer the new asyncio client API (websockets >= 13). Fall back to the legacy +# top-level connect for older versions the SDK still supports per ``setup.py`` +# (``websockets>=11.0``). The two APIs differ only in the header-kwarg name +# (``additional_headers`` vs ``extra_headers``); the ``websocket_connect_async`` +# wrapper below papers that over so tests and callers see one entry point. 
+try: + from websockets.asyncio.client import connect as _ws_connect + + _WS_HEADER_KW = "additional_headers" +except ImportError: # pragma: no cover - exercised on websockets <13 only + from websockets.client import connect as _ws_connect # type: ignore[no-redef] + + _WS_HEADER_KW = "extra_headers" + +from assemblyai import __version__ + +from .client import ( + _build_headers, + _build_uri, + _dump_model, + _dump_model_json, + _emit_param_warnings, + _normalize_min_turn_silence, + _parse_model, +) +from .models import ( + BeginEvent, + ErrorEvent, + EventMessage, + ForceEndpoint, + LLMGatewayResponseEvent, + OperationMessage, + SpeechStartedEvent, + StreamingClientOptions, + StreamingError, + StreamingErrorCodes, + StreamingEvents, + StreamingParameters, + StreamingSessionParameters, + TerminateSession, + TerminationEvent, + TurnEvent, + UpdateConfiguration, + WarningEvent, +) + +logger = logging.getLogger(__name__) + + +def websocket_connect_async(uri: str, additional_headers): + """Open a websocket connection using whichever ``websockets`` API is + available. Returns the underlying ``Connect`` awaitable so callers may + ``await`` it directly (or wrap in ``asyncio.wait_for``). Module-level + indirection so tests can patch a single attribute.""" + return _ws_connect(uri, **{_WS_HEADER_KW: additional_headers}) + + +class AsyncStreamingClient: + """Asyncio-native counterpart to ``StreamingClient``. + + The public API mirrors the thread-based client one-to-one — same options, + parameters, events, and event-handler registration. Methods that touch the + network are coroutines. Event handlers may be plain callables or + coroutine functions; coroutine handlers are awaited inline by the single + internal read task. Handlers should therefore avoid indefinite blocking, + just as with the sync client. 
async def connect(self, params: StreamingParameters) -> None:
    """Open the streaming WebSocket and start the read and write tasks.

    Handshake failures are not raised to the caller. They are routed
    through ``_report_connection_closed`` (which notifies the registered
    ``Error`` handlers) and the method then returns without starting any
    task.

    Args:
        params: Session parameters; encoded into the connection URI.

    Raises:
        RuntimeError: If this instance already holds a websocket or its
            read task is still running.
    """
    if self._websocket is not None or (
        self._read_task is not None and not self._read_task.done()
    ):
        raise RuntimeError(
            "AsyncStreamingClient is already connected; "
            "create a new instance for a new connection."
        )

    _emit_param_warnings(params)

    uri = _build_uri(self._options.api_host, params)
    headers = _build_headers(self._options)

    try:
        # The handshake is bounded to 15 seconds; ``wait_for`` cancels the
        # pending connect on timeout.
        self._websocket = await asyncio.wait_for(
            websocket_connect_async(uri, additional_headers=headers),
            timeout=15,
        )
    except websockets.exceptions.InvalidHandshake as exc:
        # The HTTP status lives in a different place depending on which
        # ``websockets`` API produced the exception: the new asyncio client
        # raises ``InvalidStatus`` (``exc.response.status_code``), while the
        # legacy client used by the import fallback raises
        # ``InvalidStatusCode`` (``exc.status_code``). Probe both so the
        # reported error carries the HTTP code on either code path.
        status_code = getattr(getattr(exc, "response", None), "status_code", None)
        if status_code is None:
            status_code = getattr(exc, "status_code", None)

        if status_code is None:
            # Some other handshake failure (bad headers, bad upgrade, ...).
            await self._report_connection_closed(exc)
        else:
            await self._report_connection_closed(
                StreamingError(
                    message=f"WebSocket handshake rejected (HTTP {status_code})",
                    code=status_code,
                )
            )
        return
    except (
        websockets.exceptions.ConnectionClosed,
        OSError,
        asyncio.TimeoutError,
        TimeoutError,
    ) as exc:
        await self._report_connection_closed(exc)
        return

    self._read_task = asyncio.create_task(
        self._read_loop(), name="AsyncStreamingClient._read_loop"
    )
    self._write_task = asyncio.create_task(
        self._write_loop(), name="AsyncStreamingClient._write_loop"
    )

    logger.debug("Connected to WebSocket server")
async def stream(
    self,
    data: Union[bytes, bytearray, memoryview, AsyncIterable[bytes], Iterable[bytes]],
) -> None:
    """Queue audio for the write task to send.

    Args:
        data: Either one bytes-like buffer (queued as a single chunk), an
            async iterable of chunks, or a plain iterable of chunks.

    Nothing is queued once the stop event is set; iteration over an
    iterable stops as soon as the stop event is observed.
    """
    if self._stop_event.is_set():
        return

    bytes_like = (bytes, bytearray, memoryview)

    # A ``bytearray``/``memoryview`` is itself iterable (yielding ints), so
    # it must be recognised here rather than falling through to the iterable
    # branch. It is copied to ``bytes`` because the write loop only sends
    # ``bytes`` instances and rejects everything else.
    if isinstance(data, bytes_like):
        await self._write_queue.put(bytes(data))
        return

    if isinstance(data, collections.abc.AsyncIterable):
        async for chunk in data:
            if self._stop_event.is_set():
                return
            if isinstance(chunk, (bytearray, memoryview)):
                chunk = bytes(chunk)
            await self._write_queue.put(chunk)
        return

    for chunk in data:
        if self._stop_event.is_set():
            return
        if isinstance(chunk, (bytearray, memoryview)):
            chunk = bytes(chunk)
        await self._write_queue.put(chunk)
asyncio.TimeoutError: + if self._stop_event.is_set(): + return + continue + + # TerminateSession bypasses the stop gate so disconnect(terminate=True) + # can always send it, even when stop is set between put() and the + # write loop's next iteration. + is_terminate = isinstance(data, TerminateSession) + if not is_terminate and self._stop_event.is_set(): + return + + try: + if isinstance(data, bytes): + await self._websocket.send(data) + elif isinstance(data, BaseModel): + await self._websocket.send(_dump_model_json(data)) + else: + raise ValueError(f"Attempted to send invalid message: {type(data)}") + except websockets.exceptions.ConnectionClosed as exc: + # Dispatch the close directly from the write task. The read + # task may short-circuit on ``_stop_event`` at the top of its + # loop (e.g. while a buffered message was processed between + # ``recv()`` calls) and never observe the close in ``recv()``, + # so the write task can't rely on it to dispatch. + # ``_report_connection_closed`` is idempotent — its flag check + # + set is synchronous (no ``await`` between them), so if the + # read task also raises ``ConnectionClosed`` it'll be a no-op. 
async def _read_loop(self) -> None:
    """Receive frames from the websocket and dispatch them until stopped.

    The loop exits when the stop event is set or when ``recv()`` raises
    ``ConnectionClosed`` (which is reported before returning). Frames that
    cannot be decoded, or that are not JSON objects, are logged and skipped
    so that one malformed frame cannot kill the read task.

    Raises:
        ValueError: If no websocket is attached.
    """
    while True:
        if not self._websocket:
            raise ValueError("Not connected to the WebSocket server")

        if self._stop_event.is_set():
            return

        try:
            message_data = await self._websocket.recv()
        except websockets.exceptions.ConnectionClosed as exc:
            await self._report_connection_closed(exc)
            return

        try:
            message_json = json.loads(message_data)
        except (json.JSONDecodeError, UnicodeDecodeError) as exc:
            # ``json.loads`` on a bytes frame can fail while decoding the
            # bytes, before any JSON parsing happens.
            logger.warning("Failed to decode message: %s", exc)
            continue

        # Valid JSON is not necessarily an object. ``_parse_message`` and
        # the ``.get`` below both need a mapping, so anything else
        # (list, string, number, null) is skipped here.
        if not isinstance(message_json, dict):
            logger.warning(
                "Unsupported message payload: expected a JSON object, got %s",
                type(message_json).__name__,
            )
            continue

        message = self._parse_message(message_json)

        if isinstance(message, ErrorEvent):
            await self._report_server_error(message)
        elif isinstance(message, WarningEvent):
            await self._handle_warning(message)
        elif message is not None:
            await self._handle_message(message)
        else:
            logger.warning("Unsupported event type: %s", message_json.get("type"))
event_type == StreamingEvents.LLMGatewayResponse: + return _parse_model(LLMGatewayResponseEvent, data) + elif event_type == StreamingEvents.Error: + return _parse_model(ErrorEvent, data) + elif event_type == StreamingEvents.Warning: + return _parse_model(WarningEvent, data) + else: + return None + elif "error" in data: + return _parse_model(ErrorEvent, data) + return None + + @staticmethod + def _parse_event_type(message_type: Optional[Any]) -> Optional[StreamingEvents]: + if not isinstance(message_type, str): + return None + try: + return StreamingEvents[message_type] + except KeyError: + return None + + async def _report_server_error(self, error: ErrorEvent) -> None: + self._server_error_reported = True + streaming_error = StreamingError(message=error.error, code=error.error_code) + logger.error("Streaming error: %s (code=%s)", error.error, error.error_code) + await self._dispatch_error(streaming_error) + + async def _report_connection_closed( + self, + error: Union[ + StreamingError, + ErrorEvent, + websockets.exceptions.ConnectionClosed, + OSError, + ], + ) -> None: + if self._connection_closed_reported: + return + self._connection_closed_reported = True + self._stop_event.set() + + streaming_error = self._build_connection_closed_error(error) + + if streaming_error is None: + await self._close_websocket() + return + + if isinstance(error, websockets.exceptions.ConnectionClosed): + reason = error.reason or "no reason given" + logger.error("Connection closed: %s (code=%s)", reason, error.code) + else: + logger.error( + "Connection failed: %s (code=%s)", + streaming_error, + streaming_error.code, + ) + + # If a server Error frame already fired on_error, the close is the + # effect, not a new cause — log it (above) but skip the duplicate + # user-visible error. 
@staticmethod
def _build_connection_closed_error(
    error: Union[
        StreamingError,
        ErrorEvent,
        websockets.exceptions.ConnectionClosed,
        OSError,
    ],
) -> Optional[StreamingError]:
    """Translate a connection failure into a ``StreamingError``.

    Args:
        error: The failure to translate: an already-built
            ``StreamingError``, a server ``ErrorEvent``, a websocket
            ``ConnectionClosed``, or any other exception.

    Returns:
        The error to surface to ``Error`` handlers, or ``None`` when the
        connection closed with code 1000 (nothing to report).
    """
    if isinstance(error, StreamingError):
        return error

    if isinstance(error, ErrorEvent):
        return StreamingError(message=error.error, code=error.error_code)

    if isinstance(error, websockets.exceptions.ConnectionClosed):
        # Read the received close frame directly instead of the
        # ``ConnectionClosed.code`` / ``.reason`` shortcuts, which newer
        # ``websockets`` releases deprecate. The fallbacks reproduce what
        # those shortcuts return when no close frame was received:
        # 1006 (abnormal closure) and an empty reason.
        rcvd = getattr(error, "rcvd", None)
        code = 1006 if rcvd is None else rcvd.code
        reason = "" if rcvd is None else rcvd.reason

        if code == 1000:
            return None

        if code in StreamingErrorCodes:
            message = StreamingErrorCodes[code]
        else:
            message = reason or f"Connection closed (code={code})"
        return StreamingError(message=message, code=code)

    return StreamingError(message=f"Connection failed: {error}")
user_agent} + + if api_key: + headers["Authorization"] = api_key + + self._http_client = httpx.AsyncClient( + base_url="https://" + api_host, + headers=headers, + ) + + async def create_temporary_token( + self, + expires_in_seconds: int, + max_session_duration_seconds: Optional[int] = None, + ) -> str: + params: Dict[str, Any] = {} + + if expires_in_seconds: + params["expires_in_seconds"] = expires_in_seconds + + if max_session_duration_seconds: + params["max_session_duration_seconds"] = max_session_duration_seconds + + response = await self._http_client.get("/v3/token", params=params) + response.raise_for_status() + return response.json()["token"] + + async def aclose(self) -> None: + try: + await self._http_client.aclose() + except Exception as exc: + logger.debug("Error closing async HTTP client: %s", exc) diff --git a/assemblyai/streaming/v3/client.py b/assemblyai/streaming/v3/client.py index a9d2581..1913c14 100644 --- a/assemblyai/streaming/v3/client.py +++ b/assemblyai/streaming/v3/client.py @@ -110,6 +110,46 @@ def _user_agent() -> str: ) +def _emit_param_warnings(params: StreamingParameters) -> None: + if params.speech_model == "u3-pro": + logger.warning( + "[Deprecation Warning] The speech model `u3-pro` is deprecated and will be removed in a future release. " + "Please use `u3-rt-pro` instead." + ) + if params.customer_support_audio_capture: + logger.warning( + "`customer_support_audio_capture=True` will record session audio. " + "Only enable this when explicitly coordinating with AssemblyAI support." 
def _build_uri(host: str, params: StreamingParameters) -> str:
    """Compose the ``/v3/ws`` endpoint URL for a streaming session.

    Args:
        host: API host. Used verbatim when it already starts with
            ``ws://`` or ``wss://``; otherwise ``wss://`` is prepended.
        params: Session parameters, dumped and normalised into the query
            string.

    Returns:
        The full WebSocket URI including the encoded query string.
    """
    dumped = _normalize_voice_focus(_normalize_min_turn_silence(_dump_model(params)))

    # List and dict values (e.g. keyterms_prompt, llm_gateway) are sent as
    # JSON text; every other value is passed to ``urlencode`` unchanged.
    query = urlencode(
        {
            name: json.dumps(value) if isinstance(value, (list, dict)) else value
            for name, value in dumped.items()
        }
    )

    has_scheme = host.startswith(("ws://", "wss://"))
    base = host if has_scheme else f"wss://{host}"
    return f"{base}/v3/ws?{query}"
- ) + _emit_param_warnings(params) - params_dict = _normalize_voice_focus( - _normalize_min_turn_silence(_dump_model(params)) - ) - - # JSON-encode list and dict parameters for proper API compatibility (e.g., keyterms_prompt, llm_gateway) - for key, value in params_dict.items(): - if isinstance(value, list): - params_dict[key] = json.dumps(value) - elif isinstance(value, dict): - params_dict[key] = json.dumps(value) - - params_encoded = urlencode(params_dict) - - host = self._options.api_host - if host.startswith(("ws://", "wss://")): - uri = f"{host}/v3/ws?{params_encoded}" - else: - uri = f"wss://{host}/v3/ws?{params_encoded}" - headers = { - "Authorization": self._options.token - if self._options.token - else self._options.api_key, - "User-Agent": _user_agent(), - "AssemblyAI-Version": "2025-05-12", - } + uri = _build_uri(self._options.api_host, params) + headers = _build_headers(self._options) try: self._websocket = websocket_connect( diff --git a/tests/unit/test_streaming_async.py b/tests/unit/test_streaming_async.py new file mode 100644 index 0000000..05fa165 --- /dev/null +++ b/tests/unit/test_streaming_async.py @@ -0,0 +1,578 @@ +import asyncio +import json +import logging +from urllib.parse import urlencode + +import pytest +from pytest_mock import MockFixture +from websockets.exceptions import ConnectionClosed, InvalidStatus +from websockets.frames import Close + +from assemblyai.streaming.v3 import ( + AsyncStreamingClient, + SpeechModel, + StreamingClientOptions, + StreamingEvents, + StreamingParameters, +) +from assemblyai.streaming.v3.models import TerminateSession + +pytestmark = pytest.mark.asyncio + + +def _default_params() -> StreamingParameters: + return StreamingParameters( + sample_rate=16000, + speech_model=SpeechModel.universal_streaming_english, + ) + + +class _FakeAsyncWebSocket: + """Programmable async websocket stand-in for driving AsyncStreamingClient + in tests. 
Inbound messages are queued via ``push_message`` / + ``push_close``; outbound sends accumulate in ``sent``. + """ + + def __init__(self, send_raises=None): + self._inbound: "asyncio.Queue[object]" = asyncio.Queue() + self._send_raises = send_raises + self.sent: list = [] + self.send_call_count = 0 + self.close_call_count = 0 + self._closed = False + + def push_message(self, data) -> None: + self._inbound.put_nowait(data) + + def push_close(self, exc: BaseException) -> None: + self._inbound.put_nowait(exc) + + async def recv(self): + item = await self._inbound.get() + if isinstance(item, BaseException): + raise item + return item + + async def send(self, data) -> None: + self.send_call_count += 1 + if self._send_raises is not None: + raise self._send_raises + self.sent.append(data) + + async def close(self) -> None: + self.close_call_count += 1 + self._closed = True + + +def _patch_connect(mocker: MockFixture, fake_ws): + """Patch ``websocket_connect_async`` to return the given fake websocket.""" + + async def fake_connect(uri, additional_headers=None, **_kwargs): + fake_connect.uri = uri + fake_connect.additional_headers = additional_headers + return fake_ws + + fake_connect.uri = None + fake_connect.additional_headers = None + mocker.patch( + "assemblyai.streaming.v3.async_client.websocket_connect_async", + new=fake_connect, + ) + return fake_connect + + +async def _wait_for_tasks(client: AsyncStreamingClient, timeout: float = 2.0) -> None: + """Wait until both read/write tasks have exited and stop is set. 
Raises + ``AssertionError`` on timeout so stalls fail tests deterministically + instead of silently passing.""" + loop = asyncio.get_running_loop() + deadline = loop.time() + timeout + while loop.time() < deadline: + read_done = client._read_task is None or client._read_task.done() + write_done = client._write_task is None or client._write_task.done() + if read_done and write_done and client._stop_event.is_set(): + return + await asyncio.sleep(0.01) + raise AssertionError( + f"AsyncStreamingClient read/write tasks did not finish within {timeout}s" + ) + + +async def test_client_connect_builds_uri_and_headers(mocker: MockFixture): + fake_ws = _FakeAsyncWebSocket() + fake_connect = _patch_connect(mocker, fake_ws) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + + params = _default_params() + await client.connect(params) + + expected_qs = urlencode( + { + "sample_rate": params.sample_rate, + "speech_model": str(params.speech_model), + } + ) + assert fake_connect.uri == f"wss://api.example.com/v3/ws?{expected_qs}" + assert fake_connect.additional_headers["Authorization"] == "test" + assert fake_connect.additional_headers["AssemblyAI-Version"] == "2025-05-12" + assert "AssemblyAI/1.0" in fake_connect.additional_headers["User-Agent"] + + await client.disconnect() + + +async def test_client_connect_with_token(mocker: MockFixture): + fake_ws = _FakeAsyncWebSocket() + fake_connect = _patch_connect(mocker, fake_ws) + + client = AsyncStreamingClient( + StreamingClientOptions(token="tok-value", api_host="api.example.com") + ) + await client.connect(_default_params()) + + assert fake_connect.additional_headers["Authorization"] == "tok-value" + + await client.disconnect() + + +async def test_stream_bytes_writes_to_socket(mocker: MockFixture): + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + 
await client.connect(_default_params()) + + await client.stream(b"\x00" * 320) + + # Give the write task a moment to drain the queue. + for _ in range(50): + if fake_ws.sent: + break + await asyncio.sleep(0.01) + + assert fake_ws.sent == [b"\x00" * 320] + + await client.disconnect() + + +async def test_stream_sync_iterable(mocker: MockFixture): + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + await client.connect(_default_params()) + + chunks = [b"a", b"bb", b"ccc"] + await client.stream(iter(chunks)) + + for _ in range(50): + if len(fake_ws.sent) == 3: + break + await asyncio.sleep(0.01) + + assert fake_ws.sent == chunks + + await client.disconnect() + + +async def test_stream_async_iterable(mocker: MockFixture): + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + await client.connect(_default_params()) + + async def gen(): + for chunk in (b"x", b"yy", b"zzz"): + yield chunk + + await client.stream(gen()) + + for _ in range(50): + if len(fake_ws.sent) == 3: + break + await asyncio.sleep(0.01) + + assert fake_ws.sent == [b"x", b"yy", b"zzz"] + + await client.disconnect() + + +async def test_disconnect_terminate_sends_terminate_then_closes(mocker: MockFixture): + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + await client.connect(_default_params()) + + await client.disconnect(terminate=True) + + sent_terminate = [ + s for s in fake_ws.sent if isinstance(s, str) and "Terminate" in s + ] + assert len(sent_terminate) == 1 + assert fake_ws.close_call_count >= 1 + + +async def test_begin_event_dispatched_to_handler(mocker: MockFixture): + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, 
fake_ws) + + received = [] + + def on_begin(_client, event): + received.append(event) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + client.on(StreamingEvents.Begin, on_begin) + await client.connect(_default_params()) + + fake_ws.push_message( + json.dumps( + { + "type": "Begin", + "id": "abc", + "expires_at": "2030-01-01T00:00:00", + } + ) + ) + + for _ in range(50): + if received: + break + await asyncio.sleep(0.01) + + assert len(received) == 1 + assert received[0].id == "abc" + + await client.disconnect() + + +async def test_async_handler_is_awaited(mocker: MockFixture): + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + seen = [] + + async def on_begin(_client, event): + await asyncio.sleep(0) + seen.append(event.id) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + client.on(StreamingEvents.Begin, on_begin) + await client.connect(_default_params()) + + fake_ws.push_message( + json.dumps( + {"type": "Begin", "id": "async-id", "expires_at": "2030-01-01T00:00:00"} + ) + ) + + for _ in range(50): + if seen: + break + await asyncio.sleep(0.01) + + assert seen == ["async-id"] + + await client.disconnect() + + +async def test_sync_and_async_handlers_can_mix(mocker: MockFixture): + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + sync_seen = [] + async_seen = [] + + def sync_handler(_client, event): + sync_seen.append(event.id) + + async def async_handler(_client, event): + async_seen.append(event.id) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + client.on(StreamingEvents.Begin, sync_handler) + client.on(StreamingEvents.Begin, async_handler) + await client.connect(_default_params()) + + fake_ws.push_message( + json.dumps({"type": "Begin", "id": "mix", "expires_at": "2030-01-01T00:00:00"}) + ) + + for _ in range(50): + if sync_seen and 
async_seen: + break + await asyncio.sleep(0.01) + + assert sync_seen == ["mix"] + assert async_seen == ["mix"] + + await client.disconnect() + + +async def test_error_event_then_close_fires_only_once( + mocker: MockFixture, caplog: pytest.LogCaptureFixture +): + caplog.set_level(logging.ERROR) + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + received = [] + + def on_error(_client, err): + received.append(err) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + client.on(StreamingEvents.Error, on_error) + await client.connect(_default_params()) + + fake_ws.push_message( + json.dumps({"type": "Error", "error": "Invalid API key", "error_code": 4001}) + ) + fake_ws.push_close(ConnectionClosed(rcvd=Close(4001, "Not Authorized"), sent=None)) + + await _wait_for_tasks(client) + + assert len(received) == 1 + assert str(received[0]) == "Invalid API key" + assert received[0].code == 4001 + + error_logs = [ + rec + for rec in caplog.records + if "Streaming error" in rec.message and "4001" in rec.message + ] + close_logs = [ + rec + for rec in caplog.records + if "Connection closed" in rec.message and "4001" in rec.message + ] + assert len(error_logs) == 1 + assert len(close_logs) == 1 + + await client.disconnect() + + +async def test_clean_close_emits_no_error_or_log( + mocker: MockFixture, caplog: pytest.LogCaptureFixture +): + caplog.set_level(logging.ERROR) + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + received = [] + + def on_error(_client, err): + received.append(err) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + client.on(StreamingEvents.Error, on_error) + await client.connect(_default_params()) + + fake_ws.push_close(ConnectionClosed(rcvd=Close(1000, "session ended"), sent=None)) + + await _wait_for_tasks(client) + + assert received == [] + error_logs = [rec for rec in caplog.records if rec.levelno >= 
async def test_handler_exception_does_not_block_shutdown(mocker: MockFixture):
    """A raising ``Error`` handler must not escape into the read task.

    ``Task.done()`` alone cannot prove this: a task that died from an
    escaped exception is also "done". The test therefore checks that the
    read task finished without an exception and that the close path still
    ran to completion (the websocket was closed).
    """
    fake_ws = _FakeAsyncWebSocket()
    _patch_connect(mocker, fake_ws)

    def bad_handler(_client, _err):
        raise RuntimeError("boom")

    client = AsyncStreamingClient(
        StreamingClientOptions(api_key="test", api_host="api.example.com")
    )
    client.on(StreamingEvents.Error, bad_handler)
    await client.connect(_default_params())

    try:
        fake_ws.push_close(
            ConnectionClosed(rcvd=Close(1011, "server error"), sent=None)
        )

        await _wait_for_tasks(client)

        assert client._read_task.done()
        # Would be the RuntimeError if the handler exception had escaped.
        assert client._read_task.exception() is None
        # The close path continued past the failing handler.
        assert fake_ws.close_call_count >= 1
    finally:
        # Always release the tasks and the underlying HTTP client.
        await client.disconnect()
+ client._stop_event.set() + await client._write_queue.put(TerminateSession()) + + for _ in range(100): + if fake_ws.send_call_count >= 1: + break + await asyncio.sleep(0.01) + + assert fake_ws.send_call_count >= 1 + assert any(isinstance(s, str) and "Terminate" in s for s in fake_ws.sent) + + await client.disconnect() + + +async def test_create_temporary_token(mocker: MockFixture): + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + + async def fake_get(self, url, params=None): + class R: + def raise_for_status(self_inner): + pass + + def json(self_inner): + return {"token": "tmp-tok"} + + return R() + + mocker.patch("httpx.AsyncClient.get", new=fake_get) + + token = await client.create_temporary_token( + expires_in_seconds=60, max_session_duration_seconds=600 + ) + assert token == "tmp-tok" + + # Clean up the (un-mocked) AsyncClient so the test doesn't emit + # "unclosed transport" warnings. + await client._client.aclose() + + +async def test_connect_twice_raises(mocker: MockFixture): + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + await client.connect(_default_params()) + + with pytest.raises(RuntimeError, match="already connected"): + await client.connect(_default_params()) + + await client.disconnect() + + +async def test_disconnect_closes_http_client(mocker: MockFixture): + fake_ws = _FakeAsyncWebSocket() + _patch_connect(mocker, fake_ws) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + await client.connect(_default_params()) + + closed = [] + + async def fake_aclose(self): + closed.append(True) + + mocker.patch("httpx.AsyncClient.aclose", new=fake_aclose) + + await client.disconnect() + + assert closed == [True] + + +async def test_write_side_close_is_dispatched_when_read_short_circuits_on_stop( + mocker: 
MockFixture, caplog: pytest.LogCaptureFixture +): + """Regression: if the read task observes ``_stop_event`` at the top of its + loop (e.g. after processing a buffered message) before its next ``recv()`` + raises, the write task must still dispatch the connection-closed event. + Previously the write task only set stop and exited, so this close went + unreported.""" + caplog.set_level(logging.ERROR) + + close_exc = ConnectionClosed(rcvd=Close(1011, "send-side close"), sent=None) + fake_ws = _FakeAsyncWebSocket(send_raises=close_exc) + _patch_connect(mocker, fake_ws) + + received = [] + + def on_error(_client, err): + received.append(err) + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + client.on(StreamingEvents.Error, on_error) + await client.connect(_default_params()) + + # Queue a write so the write task hits send() and raises ConnectionClosed. + await client.stream(b"\x00" * 32) + + # Wait for write task to finish dispatching the close. + for _ in range(200): + if received: + break + await asyncio.sleep(0.01) + + assert ( + len(received) == 1 + ), f"expected exactly one on_error from write-side close, got {received}" + assert received[0].code == 1011 + + await client.disconnect() diff --git a/tox.ini b/tox.ini index 3bfddd2..98ded94 100644 --- a/tox.ini +++ b/tox.ini @@ -27,6 +27,7 @@ deps = pytest-xdist pytest-mock pytest-cov + pytest-asyncio factory-boy allowlist_externals = pytest