Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions src/agents/realtime/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@
"""The audio format for realtime audio streams."""


class RealtimeCustomVoice(TypedDict):
"""A custom Realtime voice object, identified by its ID.

This is one of the shapes accepted by ``RealtimeVoice``, alongside a plain string.
"""

id: str
"""The custom voice ID."""


RealtimeVoice: TypeAlias = str | RealtimeCustomVoice | Mapping[str, Any]
"""The voice to use for realtime audio output."""


RealtimeReasoningEffort: TypeAlias = Literal["minimal", "low", "medium", "high", "xhigh"] | str
"""The reasoning effort for realtime model responses."""

Expand Down Expand Up @@ -124,7 +135,7 @@ class RealtimeAudioOutputConfig(TypedDict, total=False):
"""Configuration for audio output in realtime sessions."""

format: RealtimeAudioFormat | OpenAIRealtimeAudioFormats
voice: str
voice: RealtimeVoice
speed: float


Expand Down Expand Up @@ -163,7 +174,7 @@ class RealtimeSessionModelSettings(TypedDict):
audio: NotRequired[RealtimeAudioConfig]
"""The audio configuration for the session."""

voice: NotRequired[str]
voice: NotRequired[RealtimeVoice]
"""The voice to use for audio output."""

speed: NotRequired[float]
Expand Down
31 changes: 29 additions & 2 deletions src/agents/realtime/openai_realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,30 @@ def get_server_event_type_adapter() -> TypeAdapter[AllRealtimeServerEvents]:
return ServerEventTypeAdapter


def _normalize_custom_voice_for_server_event_validation(value: Any) -> Any:
if isinstance(value, list):
return [_normalize_custom_voice_for_server_event_validation(item) for item in value]

if not isinstance(value, dict):
return value

normalized: dict[str, Any] = {}
for key, item in value.items():
if key == "voice" and isinstance(item, Mapping):
voice_id = item.get("id")
if isinstance(voice_id, str):
normalized[key] = voice_id
continue
normalized[key] = _normalize_custom_voice_for_server_event_validation(item)
return normalized


def _create_realtime_audio_output(audio_output_args: dict[str, Any]) -> Any:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we upgrade the openai package to openai>=2.36.0, this workaround is no longer necessary, whereas _normalize_custom_voice_for_server_event_validation is still required even with the latest version.

Can you add quick TODO comments to these internal workarounds explaining why they exist and when they can be removed?

return cast(Any, OpenAIRealtimeAudioOutput).model_construct(
_fields_set=set(audio_output_args), **audio_output_args
)


async def _collect_enabled_handoffs(
agent: RealtimeAgent[Any], context_wrapper: RunContextWrapper[Any]
) -> list[Handoff[Any, RealtimeAgent[Any]]]:
Expand Down Expand Up @@ -1054,7 +1078,10 @@ async def _handle_ws_event(self, event: dict[str, Any]):
try:
if "previous_item_id" in event and event["previous_item_id"] is None:
event["previous_item_id"] = "" # TODO (rm) remove
parsed: AllRealtimeServerEvents = self._server_event_type_adapter.validate_python(event)
validation_event = _normalize_custom_voice_for_server_event_validation(event)
parsed: AllRealtimeServerEvents = self._server_event_type_adapter.validate_python(
validation_event
)
except pydantic.ValidationError as e:
logger.error(f"Failed to validate server event: {event}", exc_info=True)
await self._emit_event(RealtimeModelErrorEvent(error=e))
Expand Down Expand Up @@ -1447,7 +1474,7 @@ def _get_session_config(
"output_modalities": output_modalities,
"audio": OpenAIRealtimeAudioConfig(
input=OpenAIRealtimeAudioInput(**audio_input_args),
output=OpenAIRealtimeAudioOutput(**audio_output_args),
output=_create_realtime_audio_output(audio_output_args),
),
"tools": self._tools_to_session_tools(
tools=model_settings.get("tools", []),
Expand Down
109 changes: 109 additions & 0 deletions tests/realtime/test_openai_realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@

import pytest
import websockets
from pydantic import BaseModel, TypeAdapter

from agents import Agent, function_tool
from agents.exceptions import UserError
from agents.handoffs import handoff
from agents.realtime import openai_realtime as openai_realtime_module
from agents.realtime.model import RealtimeModelConfig
from agents.realtime.model_events import (
RealtimeModelAudioEvent,
Expand Down Expand Up @@ -445,6 +447,80 @@ async def test_handle_invalid_event_schema_logs_error(self, model):
error_event = mock_listener.on_event.call_args_list[1][0][0]
assert error_event.type == "error"

@pytest.mark.asyncio
async def test_custom_voice_response_events_update_response_sequencer(self, model, monkeypatch):
"""Dict-shaped custom voices should not block response.create sequencing."""
# Types of every payload the model tries to send, in order.
payload_types: list[str] = []

# Replacement for the raw sender: record the payload type instead of sending anything.
async def fake_send_raw(event):
payload_types.append(event.type)

# Stand-in for the server event adapter. It looks only at response.audio.output.voice:
# a dict-shaped voice is validated against ``str`` (which raises), so an un-normalized
# custom voice fails validation here. Otherwise it returns a minimal object that
# carries just the event type.
class CustomVoiceRejectingAdapter:
_string_adapter = TypeAdapter(str)

def validate_python(self, event):
voice = event.get("response", {}).get("audio", {}).get("output", {}).get("voice")
if isinstance(voice, dict):
self._string_adapter.validate_python(voice)
return SimpleNamespace(type=event["type"])

monkeypatch.setattr(model, "_send_raw_message", fake_send_raw)
model._server_event_type_adapter = CustomVoiceRejectingAdapter()
mock_listener = AsyncMock()
model.add_listener(mock_listener)

await model._send_user_input(RealtimeModelSendUserInput(user_input="hi"))
# Yield to the event loop once before inspecting what was sent
# (presumably lets a scheduled response.create go out -- TODO confirm).
await asyncio.sleep(0)

assert payload_types == ["conversation.item.create", "response.create"]
assert model._response_control == "create_requested"

# Kept in a variable so the identity and non-mutation checks below can refer to it.
response_with_custom_voice = {
"type": "response.created",
"response": {"audio": {"output": {"voice": {"id": "voice_test"}}}},
}
await model._handle_ws_event(response_with_custom_voice)

# response.created with a dict-shaped voice must still be processed.
assert model._ongoing_response is True
assert model._response_control == "free"

await model._handle_ws_event(
{
"type": "response.done",
"response": {"audio": {"output": {"voice": {"id": "voice_test"}}}},
}
)

assert model._ongoing_response is False
assert model._response_control == "free"
# The first event delivered to the listener must carry the very same event object,
# with its voice still in dict form (i.e. the handler did not mutate the input).
raw_event = mock_listener.on_event.call_args_list[0][0][0]
assert raw_event.data is response_with_custom_voice
assert response_with_custom_voice["response"]["audio"]["output"]["voice"] == {
"id": "voice_test"
}

# A tool output that asks for a new response must be able to send response.create again.
await model._send_tool_output(
RealtimeModelSendToolOutput(
tool_call=SimpleNamespace(
id="item_1",
previous_item_id=None,
call_id="call_1",
arguments="{}",
name="lookup",
),
output="ok",
start_response=True,
)
)
await asyncio.sleep(0)

assert payload_types == [
"conversation.item.create",
"response.create",
"conversation.item.create",
"response.create",
]

@pytest.mark.asyncio
async def test_handle_unknown_event_type_ignored(self, model):
"""Test that unknown event types are ignored gracefully."""
Expand Down Expand Up @@ -1519,6 +1595,39 @@ def test_get_and_update_session_config(self, model):
assert cfg.audio is not None and cfg.audio.output is not None
assert cfg.audio.output.voice == "verse"

def test_session_config_accepts_custom_voice_object(self, model):
    """A dict-shaped voice passed at the top level of the settings survives into the dump."""
    voice = {"id": "voice_test"}

    session_config = model._get_session_config({"voice": voice})
    dumped_output = session_config.model_dump(exclude_unset=True)["audio"]["output"]

    assert dumped_output["voice"] == voice

def test_session_config_accepts_nested_custom_voice_object(self, model):
    """A dict-shaped voice nested under audio.output survives into the dumped config."""
    voice = {"id": "voice_test"}

    dumped = model._get_session_config({"audio": {"output": {"voice": voice}}}).model_dump(
        exclude_unset=True
    )

    assert dumped["audio"]["output"]["voice"] == voice

def test_audio_output_accepts_custom_voice_with_older_generated_model(self, monkeypatch):
    """Dict voices survive even if the audio output model types ``voice`` as ``str``."""

    # Mimics a generated model whose ``voice`` field only allows strings.
    class LegacyRealtimeAudioOutput(BaseModel):
        voice: str | None = None
        format: Any = None
        speed: float | None = None

    monkeypatch.setattr(
        openai_realtime_module, "OpenAIRealtimeAudioOutput", LegacyRealtimeAudioOutput
    )
    voice = {"id": "voice_test"}

    built = openai_realtime_module._create_realtime_audio_output({"voice": voice})
    dumped = built.model_dump(exclude_unset=True)

    assert built.voice == voice
    assert dumped["voice"] == voice

def test_session_config_defaults_audio_formats_when_not_call(self, model):
settings: dict[str, Any] = {}
cfg = model._get_session_config(settings)
Expand Down
36 changes: 36 additions & 0 deletions tests/realtime/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1386,6 +1386,42 @@ async def test_handoff_tool_handling(self, mock_model):
# Verify agent was updated
assert session._current_agent == second_agent

@pytest.mark.asyncio
async def test_handoff_session_update_preserves_custom_voice(self, mock_model):
    """The session update sent on handoff keeps the dict-shaped voice from initial settings."""
    voice = {"id": "voice_test"}
    source_agent = RealtimeAgent(
        name="first_agent",
        instructions="first_agent_instructions",
        tools=[],
        handoffs=[],
    )
    target_agent = RealtimeAgent(
        name="second_agent",
        instructions="second_agent_instructions",
        tools=[],
        handoffs=[],
    )
    # Wire the handoff after both agents exist.
    source_agent.handoffs = [target_agent]
    session = RealtimeSession(
        mock_model,
        source_agent,
        None,
        model_config={"initial_model_settings": {"voice": voice}},
    )
    handoff_call = RealtimeModelToolCallEvent(
        name=Handoff.default_tool_name(target_agent),
        call_id="call_789",
        arguments="{}",
    )

    await session._handle_tool_call(handoff_call)

    # First sent event is the session update; the second is expected to start a response.
    session_update = mock_model.sent_events[0]
    assert isinstance(session_update, RealtimeModelSendSessionUpdate)
    assert session_update.session_settings["voice"] == voice
    assert mock_model.sent_events[1].start_response is True

@pytest.mark.asyncio
async def test_unknown_tool_handling(self, mock_model, mock_agent, mock_function_tool):
"""Test that unknown tools complete the model call without starting a response."""
Expand Down
Loading