From 0b7468fd7a5f42cf095c3eff90fe387d84473298 Mon Sep 17 00:00:00 2001 From: AssemblyAI Date: Tue, 12 May 2026 02:11:47 -0600 Subject: [PATCH] Project import generated by Copybara. GitOrigin-RevId: 4499599d2a8f9f6f0e0c7c366fce54edfbd802a8 --- assemblyai/__version__.py | 2 +- assemblyai/streaming/v3/client.py | 22 +++- assemblyai/streaming/v3/models.py | 1 + tests/unit/test_streaming.py | 193 ++++++++++++++++++++++++++++-- 4 files changed, 204 insertions(+), 14 deletions(-) diff --git a/assemblyai/__version__.py b/assemblyai/__version__.py index d720aed..ea2be9a 100644 --- a/assemblyai/__version__.py +++ b/assemblyai/__version__.py @@ -1 +1 @@ -__version__ = "0.64.0" +__version__ = "0.64.1" diff --git a/assemblyai/streaming/v3/client.py b/assemblyai/streaming/v3/client.py index a9d2581..cce42d8 100644 --- a/assemblyai/streaming/v3/client.py +++ b/assemblyai/streaming/v3/client.py @@ -206,7 +206,10 @@ def connect(self, params: StreamingParameters) -> None: logger.debug("Connected to WebSocket server") def disconnect(self, terminate: bool = False) -> None: - if terminate and not self._stop_event.is_set(): + # Enqueue Terminate even when stop is already set: `_write_message` + # bypasses the stop gate for TerminateSession so the frame still + # reaches the server when the write thread is alive. + if terminate: self._write_queue.put(TerminateSession()) self._stop_event.set() @@ -341,7 +344,10 @@ def _handle_message(self, message: EventMessage) -> None: event_type = StreamingEvents[message.type] for handler in self._handlers[event_type]: - handler(self, message) + try: + handler(self, message) + except Exception: + logger.exception("on_%s handler raised", event_type.name.lower()) def _parse_message(self, data: Dict[str, Any]) -> Optional[EventMessage]: if "type" in data: @@ -385,7 +391,10 @@ def _handle_warning(self, warning: WarningEvent): "Streaming warning (code=%s): %s", warning.warning_code, warning.warning ) for handler in self._handlers[StreamingEvents.Warning]: - handler(self, warning) + try: + handler(self, warning) + except Exception: + logger.exception("on_warning handler raised") def _report_server_error(self, error: ErrorEvent) -> None: self._server_error_reported = True @@ -395,6 +404,13 @@ def _report_server_error(self, error: ErrorEvent) -> None: ) logger.error("Streaming error: %s (code=%s)", error.error, error.error_code) self._dispatch_error(streaming_error) + # Tear down locally so a server that sends Error without a trailing + # close frame doesn't leave the read loop spinning in recv(timeout=1) + # forever. `_close_websocket` is idempotent; if the trailing close + # does arrive, `_report_connection_closed` will dedup via + # `_server_error_reported`. + self._close_websocket() + self._stop_event.set() def _report_connection_closed( self, diff --git a/assemblyai/streaming/v3/models.py b/assemblyai/streaming/v3/models.py index 4115f9f..5b35a1c 100644 --- a/assemblyai/streaming/v3/models.py +++ b/assemblyai/streaming/v3/models.py @@ -105,6 +105,7 @@ class StreamingSessionParameters(BaseModel): keyterms_prompt: Optional[List[str]] = None filter_profanity: Optional[bool] = None prompt: Optional[str] = None + interruption_delay: Optional[int] = None class Encoding(str, Enum): diff --git a/tests/unit/test_streaming.py b/tests/unit/test_streaming.py index 8151899..eb40850 100644 --- a/tests/unit/test_streaming.py +++ b/tests/unit/test_streaming.py @@ -650,6 +650,37 @@ def mocked_websocket_connect( assert "continuous_partials=True" in actual_url +def test_client_connect_with_interruption_delay(mocker: MockFixture): + # Given: client + interruption_delay=500 (U3-Pro early-partial override) + actual_url = None + + def mocked_websocket_connect( + url: str, additional_headers: dict, open_timeout: float + ): + nonlocal actual_url + actual_url = url + + mocker.patch( + "assemblyai.streaming.v3.client.websocket_connect", + new=mocked_websocket_connect, + ) + _disable_rw_threads(mocker) + client = StreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + params = StreamingParameters( + sample_rate=16000, + speech_model=SpeechModel.u3_rt_pro, + interruption_delay=500, + ) + + # When: connect + client.connect(params) + + # Then: parameter reaches the URL + assert "interruption_delay=500" in actual_url + + def test_customer_support_audio_capture_warns_when_enabled( mocker: MockFixture, caplog: pytest.LogCaptureFixture ): @@ -986,7 +1017,10 @@ def on_error(self_, err): seed_chunks=[b"\x00" * 320] * 50, ) - # Then: exactly one on_error with the rich server-error content. + # Then: exactly one on_error with the rich server-error content. The + # local websocket has been closed (by _report_server_error). Whether the + # trailing close-frame race produces an additional "Connection closed" + # log depends on scheduling, but dedup ensures no second on_error fires. assert len(received) == 1, ( f"expected exactly 1 error, got {len(received)}: {received}" ) @@ -1001,19 +1035,10 @@ def on_error(self_, err): for rec in caplog.records if "Streaming error" in rec.message and "4001" in rec.message ] - close_logs = [ - rec - for rec in caplog.records - if "Connection closed" in rec.message and "4001" in rec.message - ] assert len(error_logs) == 1, ( f"expected exactly 1 Streaming-error log, got {len(error_logs)}" ) assert error_logs[0].levelno == logging.ERROR - assert len(close_logs) == 1, ( - f"expected exactly 1 Connection-closed log, got {len(close_logs)}" - ) - assert close_logs[0].levelno == logging.ERROR client.disconnect(terminate=True) @@ -1204,3 +1229,151 @@ def test_write_thread_close_is_drained_by_read_thread(mocker: MockFixture): assert received[0].code == 1011 client.disconnect() + + +def test_server_error_without_trailing_close_exits_read_loop(mocker: MockFixture): + # Given: server sends an Error frame and then nothing (no close). Without + # _report_server_error setting _stop_event, the read loop would call + # recv(timeout=1) forever after dispatching the error. + error_json = json.dumps( + {"type": "Error", "error": "Server boom", "error_code": 5001} + ) + fake_ws = _FakeWebSocket(recv_script=[error_json]) + mocker.patch( + "assemblyai.streaming.v3.client.websocket_connect", + return_value=fake_ws, + ) + received = [] + client = StreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + client.on(StreamingEvents.Error, lambda c, e: received.append(e)) + + # When: connect and let the read thread dispatch the Error + _connect_and_wait(client, _default_params()) + + # Then: error was dispatched once and the read thread exited despite the + # absence of a trailing close frame. + assert len(received) == 1 + assert received[0].code == 5001 + assert client._stop_event.is_set() + assert not client._read_thread.is_alive() + assert not client._write_thread.is_alive() + + client.disconnect(terminate=True) + + +def test_disconnect_terminate_enqueues_when_stop_already_set(mocker: MockFixture): + # Given: a client whose _stop_event is already set (e.g. after a server + # error invoked _report_server_error). Threads were never started, so the + # only observable side-effect of disconnect(terminate=True) is the queue. + fake_ws = _FakeWebSocket(recv_script=[]) + mocker.patch( + "assemblyai.streaming.v3.client.websocket_connect", + return_value=fake_ws, + ) + client = StreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + client._websocket = fake_ws + client._stop_event.set() + + # When: disconnect(terminate=True) runs after stop is already set + client.disconnect(terminate=True) + + # Then: TerminateSession was enqueued unconditionally; the disconnect-side + # guard no longer silently swallows the terminate intent. + assert client._write_queue.qsize() == 1 + msg = client._write_queue.get_nowait() + assert isinstance(msg, TerminateSession) + + +def test_message_handler_exception_does_not_kill_read_thread(mocker: MockFixture): + # Given: a Turn handler that raises, followed by a Termination event. If + # the exception escapes _handle_message, the read thread dies before + # processing the Termination event. + turn_json = json.dumps( + { + "type": "Turn", + "turn_order": 1, + "turn_is_formatted": True, + "end_of_turn": True, + "transcript": "hi", + "end_of_turn_confidence": 0.9, + "words": [], + } + ) + termination_json = json.dumps( + { + "type": "Termination", + "audio_duration_seconds": 1, + "session_duration_seconds": 1, + } + ) + fake_ws = _FakeWebSocket(recv_script=[turn_json, termination_json]) + mocker.patch( + "assemblyai.streaming.v3.client.websocket_connect", + return_value=fake_ws, + ) + turns = [] + terminations = [] + + def bad_turn_handler(self_, msg): + turns.append(msg) + raise RuntimeError("boom") + + client = StreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + client.on(StreamingEvents.Turn, bad_turn_handler) + client.on(StreamingEvents.Termination, lambda c, e: terminations.append(e)) + + # When: connect; the read thread processes the Turn (handler raises) then + # the Termination (which sets _stop_event and exits the loop) + _connect_and_wait(client, _default_params()) + + # Then: read thread survived the raising handler and processed Termination. + assert len(turns) == 1 + assert len(terminations) == 1 + assert client._stop_event.is_set() + assert not client._read_thread.is_alive() + + client.disconnect() + + +def test_warning_handler_exception_does_not_kill_read_thread(mocker: MockFixture): + # Given: a Warning handler that raises, followed by a clean close. + warning_json = json.dumps( + {"type": "Warning", "warning": "session ending soon", "warning_code": 1234} + ) + clean_close = ConnectionClosed(rcvd=Close(1000, "session ended"), sent=None) + fake_ws = _FakeWebSocket(recv_script=[warning_json, clean_close]) + mocker.patch( + "assemblyai.streaming.v3.client.websocket_connect", + return_value=fake_ws, + ) + warnings_received = [] + errors_received = [] + + def bad_warning_handler(self_, w): + warnings_received.append(w) + raise RuntimeError("boom") + + client = StreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + client.on(StreamingEvents.Warning, bad_warning_handler) + client.on(StreamingEvents.Error, lambda c, e: errors_received.append(e)) + + # When: connect; the read thread processes the warning (handler raises) + # then the clean close + _connect_and_wait(client, _default_params()) + + # Then: warning was delivered, read thread survived, clean close completed + # without dispatching an error. + assert len(warnings_received) == 1 + assert errors_received == [] + assert client._stop_event.is_set() + assert not client._read_thread.is_alive() + + client.disconnect()