Commit b805bfc

feat: allow custom-client for OpenAIModel and GeminiModel
1 parent 4342fda · commit b805bfc

File tree: 4 files changed (+263, -12 lines)

src/strands/models/gemini.py

Lines changed: 52 additions & 5 deletions
@@ -48,23 +48,41 @@ class GeminiConfig(TypedDict, total=False):
     def __init__(
         self,
         *,
+        client: Optional[genai.Client] = None,
         client_args: Optional[dict[str, Any]] = None,
         **model_config: Unpack[GeminiConfig],
     ) -> None:
         """Initialize provider instance.
 
         Args:
+            client: Pre-configured Gemini client to reuse across requests.
+                When provided, this client will be reused for all requests and will NOT be closed
+                by the model. The caller is responsible for managing the client lifecycle.
+                This is useful for:
+                - Injecting custom client wrappers
+                - Reusing connection pools within a single event loop/worker
+                - Centralizing observability, retries, and networking policy
+                Note: The client should not be shared across different asyncio event loops.
             client_args: Arguments for the underlying Gemini client (e.g., api_key).
                 For a complete list of supported arguments, see https://googleapis.github.io/python-genai/.
+                Note: If `client` is provided, this parameter is ignored.
             **model_config: Configuration options for the Gemini model.
+
+        Raises:
+            ValueError: If both `client` and `client_args` are provided.
         """
         validate_config_keys(model_config, GeminiModel.GeminiConfig)
         self.config = GeminiModel.GeminiConfig(**model_config)
 
-        logger.debug("config=<%s> | initializing", self.config)
+        # Validate that only one client configuration method is provided
+        if client is not None and client_args is not None and len(client_args) > 0:
+            raise ValueError("Only one of 'client' or 'client_args' should be provided, not both.")
 
+        self._injected_client = client
         self.client_args = client_args or {}
 
+        logger.debug("config=<%s> | initializing", self.config)
+
     @override
     def update_config(self, **model_config: Unpack[GeminiConfig]) -> None:  # type: ignore[override]
         """Update the Gemini model configuration with the provided arguments.
@@ -83,6 +101,26 @@ def get_config(self) -> GeminiConfig:
         """
         return self.config
 
+    def _get_client(self) -> genai.Client:
+        """Get a Gemini client for making requests.
+
+        This method handles client lifecycle management:
+        - If an injected client was provided during initialization, it returns that client
+          without managing its lifecycle (caller is responsible for cleanup).
+        - Otherwise, creates a new genai.Client from client_args. Note that unlike OpenAI,
+          Gemini clients don't require explicit cleanup, so the created client can be used
+          directly without context manager wrapping.
+
+        Returns:
+            genai.Client: A Gemini client instance.
+        """
+        if self._injected_client is not None:
+            # Use the injected client (caller manages lifecycle)
+            return self._injected_client
+        else:
+            # Create a new client from client_args
+            return genai.Client(**self.client_args)
+
     def _format_request_content_part(self, content: ContentBlock) -> genai.types.Part:
         """Format content block into a Gemini part instance.
 
@@ -365,9 +403,10 @@ async def stream(
         """
         request = self._format_request(messages, tool_specs, system_prompt, self.config.get("params"))
 
-        client = genai.Client(**self.client_args).aio
+        client_aio = self._get_client().aio
+
         try:
-            response = await client.models.generate_content_stream(**request)
+            response = await client_aio.models.generate_content_stream(**request)
 
             yield self._format_chunk({"chunk_type": "message_start"})
             yield self._format_chunk({"chunk_type": "content_start", "data_type": "text"})
@@ -448,6 +487,14 @@ async def structured_output(
             "response_schema": output_model.model_json_schema(),
         }
         request = self._format_request(prompt, None, system_prompt, params)
-        client = genai.Client(**self.client_args).aio
-        response = await client.models.generate_content(**request)
+
+        # Determine which client to use based on configuration
+        if self._injected_client is not None:
+            # Use the injected client (caller manages lifecycle)
+            client_aio = self._injected_client.aio
+        else:
+            # Create a new client from client_args
+            client_aio = genai.Client(**self.client_args).aio
+
+        response = await client_aio.models.generate_content(**request)
         yield {"output": output_model.model_validate(response.parsed)}

src/strands/models/openai.py

Lines changed: 59 additions & 6 deletions
@@ -7,7 +7,8 @@
 import json
 import logging
 import mimetypes
-from typing import Any, AsyncGenerator, Optional, Protocol, Type, TypedDict, TypeVar, Union, cast
+from contextlib import asynccontextmanager
+from typing import Any, AsyncGenerator, AsyncIterator, Optional, Protocol, Type, TypedDict, TypeVar, Union, cast
 
 import openai
 from openai.types.chat.parsed_chat_completion import ParsedChatCompletion
@@ -55,16 +56,40 @@ class OpenAIConfig(TypedDict, total=False):
         model_id: str
         params: Optional[dict[str, Any]]
 
-    def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[OpenAIConfig]) -> None:
+    def __init__(
+        self,
+        client: Optional[Client] = None,
+        client_args: Optional[dict[str, Any]] = None,
+        **model_config: Unpack[OpenAIConfig],
+    ) -> None:
         """Initialize provider instance.
 
         Args:
-            client_args: Arguments for the OpenAI client.
+            client: Pre-configured OpenAI-compatible client to reuse across requests.
+                When provided, this client will be reused for all requests and will NOT be closed
+                by the model. The caller is responsible for managing the client lifecycle.
+                This is useful for:
+                - Injecting custom client wrappers (e.g., GuardrailsAsyncOpenAI)
+                - Reusing connection pools within a single event loop/worker
+                - Centralizing observability, retries, and networking policy
+                - Pointing to custom model gateways
+                Note: The client should not be shared across different asyncio event loops.
+            client_args: Arguments for the OpenAI client (legacy approach).
                 For a complete list of supported arguments, see https://pypi.org/project/openai/.
+                Note: If `client` is provided, this parameter is ignored.
             **model_config: Configuration options for the OpenAI model.
+
+        Raises:
+            ValueError: If both `client` and `client_args` are provided.
         """
         validate_config_keys(model_config, self.OpenAIConfig)
         self.config = dict(model_config)
+
+        # Validate that only one client configuration method is provided
+        if client is not None and client_args is not None and len(client_args) > 0:
+            raise ValueError("Only one of 'client' or 'client_args' should be provided, not both.")
+
+        self._injected_client = client
         self.client_args = client_args or {}
 
         logger.debug("config=<%s> | initializing", self.config)
@@ -422,6 +447,34 @@ def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent:
             case _:
                 raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type")
 
+    @asynccontextmanager
+    async def _get_client(self) -> AsyncIterator[Client]:
+        """Get an OpenAI client for making requests.
+
+        This context manager handles client lifecycle management:
+        - If an injected client was provided during initialization, it yields that client
+          without closing it (caller manages lifecycle).
+        - Otherwise, creates a new AsyncOpenAI client from client_args and automatically
+          closes it when the context exits.
+
+        Note: We create a new client per request to avoid connection sharing in the underlying
+        httpx client, as the asyncio event loop does not allow connections to be shared.
+        For more details, see https://github.com/encode/httpx/discussions/2959.
+
+        Yields:
+            Client: An OpenAI-compatible client instance.
+        """
+        if self._injected_client is not None:
+            # Use the injected client (caller manages lifecycle)
+            yield self._injected_client
+        else:
+            # Create a new client from client_args
+            # We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying
+            # httpx client. The asyncio event loop does not allow connections to be shared. For more details, please
+            # refer to https://github.com/encode/httpx/discussions/2959.
+            async with openai.AsyncOpenAI(**self.client_args) as client:
+                yield client
+
     @override
     async def stream(
         self,
@@ -457,7 +510,7 @@ async def stream(
         # We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying httpx
         # client. The asyncio event loop does not allow connections to be shared. For more details, please refer to
         # https://github.com/encode/httpx/discussions/2959.
-        async with openai.AsyncOpenAI(**self.client_args) as client:
+        async with self._get_client() as client:
             try:
                 response = await client.chat.completions.create(**request)
             except openai.BadRequestError as e:
@@ -576,9 +629,9 @@ async def structured_output(
         # We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying httpx
         # client. The asyncio event loop does not allow connections to be shared. For more details, please refer to
        # https://github.com/encode/httpx/discussions/2959.
-        async with openai.AsyncOpenAI(**self.client_args) as client:
+        async with self._get_client() as client:
             try:
-                response: ParsedChatCompletion = await client.beta.chat.completions.parse(
+                response: ParsedChatCompletion = await client.beta.chat.completions.parse(  # type: ignore[attr-defined]
                     model=self.get_config()["model_id"],
                     messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
                     response_format=output_model,
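Likewise, a sketch of injecting a caller-owned client into OpenAIModel (not from the commit; the base_url and model id are placeholders):

    import openai

    from strands.models.openai import OpenAIModel

    # Caller-owned client; this could equally be a wrapper such as
    # GuardrailsAsyncOpenAI, or a client pointed at a custom gateway.
    client = openai.AsyncOpenAI(api_key="...", base_url="https://gateway.example.com/v1")
    model = OpenAIModel(client=client, model_id="gpt-4o", params={"max_tokens": 256})

    # Without an injected client, _get_client() falls back to creating and
    # closing a fresh AsyncOpenAI per request, exactly as before this commit.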

tests/strands/models/test_gemini.py

Lines changed: 74 additions & 0 deletions
@@ -637,3 +637,77 @@ async def test_stream_handles_non_json_error(gemini_client, model, messages, caplog):
 
     assert "Gemini API returned non-JSON error" in caplog.text
     assert f"error_message=<{error_message}>" in caplog.text
+
+
+@pytest.mark.asyncio
+async def test_stream_with_injected_client(model_id, agenerator, alist):
+    """Test that stream works with an injected client and doesn't close it."""
+    # Create a mock injected client
+    mock_injected_client = unittest.mock.Mock()
+    mock_injected_client.aio = unittest.mock.AsyncMock()
+
+    mock_injected_client.aio.models.generate_content_stream.return_value = agenerator(
+        [
+            genai.types.GenerateContentResponse(
+                candidates=[
+                    genai.types.Candidate(
+                        content=genai.types.Content(
+                            parts=[genai.types.Part(text="Hello")],
+                        ),
+                        finish_reason="STOP",
+                    ),
+                ],
+                usage_metadata=genai.types.GenerateContentResponseUsageMetadata(
+                    prompt_token_count=1,
+                    total_token_count=3,
+                ),
+            ),
+        ]
+    )
+
+    # Create model with injected client
+    model = GeminiModel(client=mock_injected_client, model_id=model_id)
+
+    messages = [{"role": "user", "content": [{"text": "test"}]}]
+    response = model.stream(messages)
+    tru_events = await alist(response)
+
+    # Verify events were generated
+    assert len(tru_events) > 0
+
+    # Verify the injected client was used
+    mock_injected_client.aio.models.generate_content_stream.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_structured_output_with_injected_client(model_id, weather_output, alist):
+    """Test that structured_output works with an injected client and doesn't close it."""
+    # Create a mock injected client
+    mock_injected_client = unittest.mock.Mock()
+    mock_injected_client.aio = unittest.mock.AsyncMock()
+
+    mock_injected_client.aio.models.generate_content.return_value = unittest.mock.Mock(
+        parsed=weather_output.model_dump()
+    )
+
+    # Create model with injected client
+    model = GeminiModel(client=mock_injected_client, model_id=model_id)
+
+    messages = [{"role": "user", "content": [{"text": "Generate weather"}]}]
+    stream = model.structured_output(type(weather_output), messages)
+    events = await alist(stream)
+
+    # Verify output was generated
+    assert len(events) == 1
+    assert events[0] == {"output": weather_output}
+
+    # Verify the injected client was used
+    mock_injected_client.aio.models.generate_content.assert_called_once()
+
+
+def test_init_with_both_client_and_client_args_raises_error():
+    """Test that providing both client and client_args raises ValueError."""
+    mock_client = unittest.mock.Mock()
+
+    with pytest.raises(ValueError, match="Only one of 'client' or 'client_args' should be provided"):
+        GeminiModel(client=mock_client, client_args={"api_key": "test"}, model_id="test-model")

tests/strands/models/test_openai.py

Lines changed: 78 additions & 1 deletion
@@ -13,7 +13,10 @@
 def openai_client():
     with unittest.mock.patch.object(strands.models.openai.openai, "AsyncOpenAI") as mock_client_cls:
         mock_client = unittest.mock.AsyncMock()
-        mock_client_cls.return_value.__aenter__.return_value = mock_client
+        # Make the mock client work as an async context manager
+        mock_client.__aenter__ = unittest.mock.AsyncMock(return_value=mock_client)
+        mock_client.__aexit__ = unittest.mock.AsyncMock(return_value=None)
+        mock_client_cls.return_value = mock_client
         yield mock_client
 
 
@@ -986,3 +989,77 @@ def test_format_request_messages_drops_cache_points():
     ]
 
     assert result == expected
+
+
+@pytest.mark.asyncio
+async def test_stream_with_injected_client(model_id, agenerator, alist):
+    """Test that stream works with an injected client and doesn't close it."""
+    # Create a mock injected client
+    mock_injected_client = unittest.mock.AsyncMock()
+    mock_injected_client.close = unittest.mock.AsyncMock()
+
+    mock_delta = unittest.mock.Mock(content="Hello", tool_calls=None, reasoning_content=None)
+    mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)])
+    mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)])
+    mock_event_3 = unittest.mock.Mock()
+
+    mock_injected_client.chat.completions.create = unittest.mock.AsyncMock(
+        return_value=agenerator([mock_event_1, mock_event_2, mock_event_3])
+    )
+
+    # Create model with injected client
+    model = OpenAIModel(client=mock_injected_client, model_id=model_id, params={"max_tokens": 1})
+
+    messages = [{"role": "user", "content": [{"text": "test"}]}]
+    response = model.stream(messages)
+    tru_events = await alist(response)
+
+    # Verify events were generated
+    assert len(tru_events) > 0
+
+    # Verify the injected client was used
+    mock_injected_client.chat.completions.create.assert_called_once()
+
+    # Verify the injected client was NOT closed
+    mock_injected_client.close.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_structured_output_with_injected_client(model_id, test_output_model_cls, alist):
+    """Test that structured_output works with an injected client and doesn't close it."""
+    # Create a mock injected client
+    mock_injected_client = unittest.mock.AsyncMock()
+    mock_injected_client.close = unittest.mock.AsyncMock()
+
+    mock_parsed_instance = test_output_model_cls(name="John", age=30)
+    mock_choice = unittest.mock.Mock()
+    mock_choice.message.parsed = mock_parsed_instance
+    mock_response = unittest.mock.Mock()
+    mock_response.choices = [mock_choice]
+
+    mock_injected_client.beta.chat.completions.parse = unittest.mock.AsyncMock(return_value=mock_response)
+
+    # Create model with injected client
+    model = OpenAIModel(client=mock_injected_client, model_id=model_id, params={"max_tokens": 1})
+
+    messages = [{"role": "user", "content": [{"text": "Generate a person"}]}]
+    stream = model.structured_output(test_output_model_cls, messages)
+    events = await alist(stream)
+
+    # Verify output was generated
+    assert len(events) == 1
+    assert events[0] == {"output": test_output_model_cls(name="John", age=30)}
+
+    # Verify the injected client was used
+    mock_injected_client.beta.chat.completions.parse.assert_called_once()
+
+    # Verify the injected client was NOT closed
+    mock_injected_client.close.assert_not_called()
+
+
+def test_init_with_both_client_and_client_args_raises_error():
+    """Test that providing both client and client_args raises ValueError."""
+    mock_client = unittest.mock.AsyncMock()
+
+    with pytest.raises(ValueError, match="Only one of 'client' or 'client_args' should be provided"):
+        OpenAIModel(client=mock_client, client_args={"api_key": "test"}, model_id="test-model")
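Assuming the repository's standard pytest setup, the new tests in both files can be selected with a `-k` expression (invocation illustrative):

    python -m pytest tests/strands/models/test_gemini.py tests/strands/models/test_openai.py \
        -k "injected_client or both_client" -v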
