diff --git a/src/agents/realtime/config.py b/src/agents/realtime/config.py index defd4428b4..8998a7db99 100644 --- a/src/agents/realtime/config.py +++ b/src/agents/realtime/config.py @@ -46,6 +46,17 @@ """The audio format for realtime audio streams.""" +class RealtimeCustomVoice(TypedDict): + """A custom Realtime voice object.""" + + id: str + """The custom voice ID.""" + + +RealtimeVoice: TypeAlias = str | RealtimeCustomVoice | Mapping[str, Any] +"""The voice to use for realtime audio output.""" + + RealtimeReasoningEffort: TypeAlias = Literal["minimal", "low", "medium", "high", "xhigh"] | str """The reasoning effort for realtime model responses.""" @@ -124,7 +135,7 @@ class RealtimeAudioOutputConfig(TypedDict, total=False): """Configuration for audio output in realtime sessions.""" format: RealtimeAudioFormat | OpenAIRealtimeAudioFormats - voice: str + voice: RealtimeVoice speed: float @@ -163,7 +174,7 @@ class RealtimeSessionModelSettings(TypedDict): audio: NotRequired[RealtimeAudioConfig] """The audio configuration for the session.""" - voice: NotRequired[str] + voice: NotRequired[RealtimeVoice] """The voice to use for audio output.""" speed: NotRequired[float] diff --git a/src/agents/realtime/openai_realtime.py b/src/agents/realtime/openai_realtime.py index 6b986c6edc..b8a47ed89f 100644 --- a/src/agents/realtime/openai_realtime.py +++ b/src/agents/realtime/openai_realtime.py @@ -370,6 +370,30 @@ def get_server_event_type_adapter() -> TypeAdapter[AllRealtimeServerEvents]: return ServerEventTypeAdapter +def _normalize_custom_voice_for_server_event_validation(value: Any) -> Any: + if isinstance(value, list): + return [_normalize_custom_voice_for_server_event_validation(item) for item in value] + + if not isinstance(value, dict): + return value + + normalized: dict[str, Any] = {} + for key, item in value.items(): + if key == "voice" and isinstance(item, Mapping): + voice_id = item.get("id") + if isinstance(voice_id, str): + normalized[key] = voice_id + continue + normalized[key] = _normalize_custom_voice_for_server_event_validation(item) + return normalized + + +def _create_realtime_audio_output(audio_output_args: dict[str, Any]) -> Any: + return cast(Any, OpenAIRealtimeAudioOutput).model_construct( + _fields_set=set(audio_output_args), **audio_output_args + ) + + async def _collect_enabled_handoffs( agent: RealtimeAgent[Any], context_wrapper: RunContextWrapper[Any] ) -> list[Handoff[Any, RealtimeAgent[Any]]]: @@ -1054,7 +1078,10 @@ async def _handle_ws_event(self, event: dict[str, Any]): try: if "previous_item_id" in event and event["previous_item_id"] is None: event["previous_item_id"] = "" # TODO (rm) remove - parsed: AllRealtimeServerEvents = self._server_event_type_adapter.validate_python(event) + validation_event = _normalize_custom_voice_for_server_event_validation(event) + parsed: AllRealtimeServerEvents = self._server_event_type_adapter.validate_python( + validation_event + ) except pydantic.ValidationError as e: logger.error(f"Failed to validate server event: {event}", exc_info=True) await self._emit_event(RealtimeModelErrorEvent(error=e)) @@ -1447,7 +1474,7 @@ def _get_session_config( "output_modalities": output_modalities, "audio": OpenAIRealtimeAudioConfig( input=OpenAIRealtimeAudioInput(**audio_input_args), - output=OpenAIRealtimeAudioOutput(**audio_output_args), + output=_create_realtime_audio_output(audio_output_args), ), "tools": self._tools_to_session_tools( tools=model_settings.get("tools", []), diff --git a/tests/realtime/test_openai_realtime.py b/tests/realtime/test_openai_realtime.py index 4ebc2aa9a3..95e42f97b8 100644 --- a/tests/realtime/test_openai_realtime.py +++ b/tests/realtime/test_openai_realtime.py @@ -7,10 +7,12 @@ import pytest import websockets +from pydantic import BaseModel, TypeAdapter from agents import Agent, function_tool from agents.exceptions import UserError from agents.handoffs import handoff +from agents.realtime import openai_realtime as openai_realtime_module from agents.realtime.model import RealtimeModelConfig from agents.realtime.model_events import ( RealtimeModelAudioEvent, @@ -445,6 +447,80 @@ async def test_handle_invalid_event_schema_logs_error(self, model): error_event = mock_listener.on_event.call_args_list[1][0][0] assert error_event.type == "error" + @pytest.mark.asyncio + async def test_custom_voice_response_events_update_response_sequencer(self, model, monkeypatch): + """Dict-shaped custom voices should not block response.create sequencing.""" + payload_types: list[str] = [] + + async def fake_send_raw(event): + payload_types.append(event.type) + + class CustomVoiceRejectingAdapter: + _string_adapter = TypeAdapter(str) + + def validate_python(self, event): + voice = event.get("response", {}).get("audio", {}).get("output", {}).get("voice") + if isinstance(voice, dict): + self._string_adapter.validate_python(voice) + return SimpleNamespace(type=event["type"]) + + monkeypatch.setattr(model, "_send_raw_message", fake_send_raw) + model._server_event_type_adapter = CustomVoiceRejectingAdapter() + mock_listener = AsyncMock() + model.add_listener(mock_listener) + + await model._send_user_input(RealtimeModelSendUserInput(user_input="hi")) + await asyncio.sleep(0) + + assert payload_types == ["conversation.item.create", "response.create"] + assert model._response_control == "create_requested" + + response_with_custom_voice = { + "type": "response.created", + "response": {"audio": {"output": {"voice": {"id": "voice_test"}}}}, + } + await model._handle_ws_event(response_with_custom_voice) + + assert model._ongoing_response is True + assert model._response_control == "free" + + await model._handle_ws_event( + { + "type": "response.done", + "response": {"audio": {"output": {"voice": {"id": "voice_test"}}}}, + } + ) + + assert model._ongoing_response is False + assert model._response_control == "free" + raw_event = mock_listener.on_event.call_args_list[0][0][0] + assert raw_event.data is response_with_custom_voice + assert response_with_custom_voice["response"]["audio"]["output"]["voice"] == { + "id": "voice_test" + } + + await model._send_tool_output( + RealtimeModelSendToolOutput( + tool_call=SimpleNamespace( + id="item_1", + previous_item_id=None, + call_id="call_1", + arguments="{}", + name="lookup", + ), + output="ok", + start_response=True, + ) + ) + await asyncio.sleep(0) + + assert payload_types == [ + "conversation.item.create", + "response.create", + "conversation.item.create", + "response.create", + ] + @pytest.mark.asyncio async def test_handle_unknown_event_type_ignored(self, model): """Test that unknown event types are ignored gracefully.""" @@ -1519,6 +1595,39 @@ def test_get_and_update_session_config(self, model): assert cfg.audio is not None and cfg.audio.output is not None assert cfg.audio.output.voice == "verse" + def test_session_config_accepts_custom_voice_object(self, model): + custom_voice = {"id": "voice_test"} + + cfg = model._get_session_config({"voice": custom_voice}) + payload = cfg.model_dump(exclude_unset=True) + + assert payload["audio"]["output"]["voice"] == custom_voice + + def test_session_config_accepts_nested_custom_voice_object(self, model): + custom_voice = {"id": "voice_test"} + + cfg = model._get_session_config({"audio": {"output": {"voice": custom_voice}}}) + payload = cfg.model_dump(exclude_unset=True) + + assert payload["audio"]["output"]["voice"] == custom_voice + + def test_audio_output_accepts_custom_voice_with_older_generated_model(self, monkeypatch): + class LegacyRealtimeAudioOutput(BaseModel): + voice: str | None = None + format: Any = None + speed: float | None = None + + monkeypatch.setattr( + openai_realtime_module, "OpenAIRealtimeAudioOutput", LegacyRealtimeAudioOutput + ) + custom_voice = {"id": "voice_test"} + + output = openai_realtime_module._create_realtime_audio_output({"voice": custom_voice}) + payload = output.model_dump(exclude_unset=True) + + assert output.voice == custom_voice + assert payload["voice"] == custom_voice + def test_session_config_defaults_audio_formats_when_not_call(self, model): settings: dict[str, Any] = {} cfg = model._get_session_config(settings) diff --git a/tests/realtime/test_session.py b/tests/realtime/test_session.py index e289bc3c9e..03148c739a 100644 --- a/tests/realtime/test_session.py +++ b/tests/realtime/test_session.py @@ -1386,6 +1386,42 @@ async def test_handoff_tool_handling(self, mock_model): # Verify agent was updated assert session._current_agent == second_agent + @pytest.mark.asyncio + async def test_handoff_session_update_preserves_custom_voice(self, mock_model): + custom_voice = {"id": "voice_test"} + first_agent = RealtimeAgent( + name="first_agent", + instructions="first_agent_instructions", + tools=[], + handoffs=[], + ) + second_agent = RealtimeAgent( + name="second_agent", + instructions="second_agent_instructions", + tools=[], + handoffs=[], + ) + first_agent.handoffs = [second_agent] + session = RealtimeSession( + mock_model, + first_agent, + None, + model_config={"initial_model_settings": {"voice": custom_voice}}, + ) + + await session._handle_tool_call( + RealtimeModelToolCallEvent( + name=Handoff.default_tool_name(second_agent), + call_id="call_789", + arguments="{}", + ) + ) + + session_update_event = mock_model.sent_events[0] + assert isinstance(session_update_event, RealtimeModelSendSessionUpdate) + assert session_update_event.session_settings["voice"] == custom_voice + assert mock_model.sent_events[1].start_response is True + @pytest.mark.asyncio async def test_unknown_tool_handling(self, mock_model, mock_agent, mock_function_tool): """Test that unknown tools complete the model call without starting a response."""