44 changes: 40 additions & 4 deletions src/strands/models/gemini.py
@@ -54,27 +54,44 @@ class GeminiConfig(TypedDict, total=False):
def __init__(
self,
*,
client: Optional[genai.Client] = None,
client_args: Optional[dict[str, Any]] = None,
**model_config: Unpack[GeminiConfig],
) -> None:
"""Initialize provider instance.

Args:
client: Pre-configured Gemini client to reuse across requests.
When provided, this client will be reused for all requests and will NOT be closed
by the model. The caller is responsible for managing the client lifecycle.
This is useful for:
- Injecting custom client wrappers
- Reusing connection pools within a single event loop/worker
- Centralizing observability, retries, and networking policy
Note: The client should not be shared across different asyncio event loops.
client_args: Arguments for the underlying Gemini client (e.g., api_key).
For a complete list of supported arguments, see https://googleapis.github.io/python-genai/.
**model_config: Configuration options for the Gemini model.

Raises:
ValueError: If both `client` and `client_args` are provided.
"""
validate_config_keys(model_config, GeminiModel.GeminiConfig)
self.config = GeminiModel.GeminiConfig(**model_config)

# Validate that only one client configuration method is provided
if client is not None and client_args is not None and len(client_args) > 0:
raise ValueError("Only one of 'client' or 'client_args' should be provided, not both.")

self._custom_client = client
self.client_args = client_args or {}

# Validate gemini_tools if provided
if "gemini_tools" in self.config:
self._validate_gemini_tools(self.config["gemini_tools"])

logger.debug("config=<%s> | initializing", self.config)

@override
def update_config(self, **model_config: Unpack[GeminiConfig]) -> None: # type: ignore[override]
"""Update the Gemini model configuration with the provided arguments.
@@ -97,6 +114,24 @@ def get_config(self) -> GeminiConfig:
"""
return self.config

def _get_client(self) -> genai.Client:
"""Get a Gemini client for making requests.

This method handles client lifecycle management:
- If an injected client was provided during initialization, it returns that client
without managing its lifecycle (caller is responsible for cleanup).
- Otherwise, creates a new genai.Client from client_args.

Returns:
genai.Client: A Gemini client instance.
"""
if self._custom_client is not None:
# Use the injected client (caller manages lifecycle)
return self._custom_client
else:
# Create a new client from client_args
return genai.Client(**self.client_args)

def _format_request_content_part(self, content: ContentBlock) -> genai.types.Part:
"""Format content block into a Gemini part instance.

@@ -382,7 +417,8 @@ async def stream(
"""
request = self._format_request(messages, tool_specs, system_prompt, self.config.get("params"))

client = self._get_client().aio

try:
response = await client.models.generate_content_stream(**request)

@@ -465,7 +501,7 @@ async def structured_output(
"response_schema": output_model.model_json_schema(),
}
request = self._format_request(prompt, None, system_prompt, params)
client = self._get_client().aio
response = await client.models.generate_content(**request)
yield {"output": output_model.model_validate(response.parsed)}

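Taken together, the gemini.py changes give callers two construction paths. A minimal usage sketch, not part of the diff itself; the api_key placeholder and model_id value are illustrative:

from google import genai

from strands.models.gemini import GeminiModel

# Injected client: reused for every request and never closed by the model.
# Create it inside the event loop/worker that will use it, since it should
# not be shared across asyncio event loops.
shared_client = genai.Client(api_key="YOUR_API_KEY")
model = GeminiModel(client=shared_client, model_id="gemini-2.0-flash")

# Legacy path: the model builds a fresh genai.Client from client_args per request.
legacy_model = GeminiModel(client_args={"api_key": "YOUR_API_KEY"}, model_id="gemini-2.0-flash")

# Passing both configuration methods raises ValueError:
# GeminiModel(client=shared_client, client_args={"api_key": "YOUR_API_KEY"})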
62 changes: 57 additions & 5 deletions src/strands/models/openai.py
@@ -7,7 +7,8 @@
import json
import logging
import mimetypes
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, AsyncIterator, Optional, Protocol, Type, TypedDict, TypeVar, Union, cast

import openai
from openai.types.chat.parsed_chat_completion import ParsedChatCompletion
@@ -55,16 +56,39 @@ class OpenAIConfig(TypedDict, total=False):
model_id: str
params: Optional[dict[str, Any]]

def __init__(
self,
client: Optional[Client] = None,
client_args: Optional[dict[str, Any]] = None,
**model_config: Unpack[OpenAIConfig],
) -> None:
"""Initialize provider instance.

Args:
client: Pre-configured OpenAI-compatible client to reuse across requests.
When provided, this client will be reused for all requests and will NOT be closed
by the model. The caller is responsible for managing the client lifecycle.
This is useful for:
- Injecting custom client wrappers (e.g., GuardrailsAsyncOpenAI)
- Reusing connection pools within a single event loop/worker
- Centralizing observability, retries, and networking policy
- Pointing to custom model gateways
Note: The client should not be shared across different asyncio event loops.
client_args: Arguments for the OpenAI client (legacy approach).
For a complete list of supported arguments, see https://pypi.org/project/openai/.
**model_config: Configuration options for the OpenAI model.

Raises:
ValueError: If both `client` and `client_args` are provided.
"""
validate_config_keys(model_config, self.OpenAIConfig)
self.config = dict(model_config)

# Validate that only one client configuration method is provided
if client is not None and client_args is not None and len(client_args) > 0:
raise ValueError("Only one of 'client' or 'client_args' should be provided, not both.")

self._custom_client = client
self.client_args = client_args or {}

logger.debug("config=<%s> | initializing", self.config)
Expand Down Expand Up @@ -422,6 +446,34 @@ def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent:
case _:
raise RuntimeError(f"chunk_type=<{event['chunk_type']}> | unknown type")

@asynccontextmanager
async def _get_client(self) -> AsyncIterator[Any]:
"""Get an OpenAI client for making requests.

This context manager handles client lifecycle management:
- If an injected client was provided during initialization, it yields that client
without closing it (caller manages lifecycle).
- Otherwise, creates a new AsyncOpenAI client from client_args and automatically
closes it when the context exits.

Note: We create a new client per request to avoid connection sharing in the underlying
httpx client, as the asyncio event loop does not allow connections to be shared.
For more details, see https://github.com/encode/httpx/discussions/2959.

Yields:
Client: An OpenAI-compatible client instance.
"""
if self._custom_client is not None:
# Use the injected client (caller manages lifecycle)
yield self._custom_client
else:
# Create a new client from client_args
# We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying
# httpx client. The asyncio event loop does not allow connections to be shared. For more details, please
# refer to https://github.com/encode/httpx/discussions/2959.
async with openai.AsyncOpenAI(**self.client_args) as client:
yield client

@override
async def stream(
self,
@@ -457,7 +509,7 @@ async def stream(
# We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying httpx
# client. The asyncio event loop does not allow connections to be shared. For more details, please refer to
# https://github.com/encode/httpx/discussions/2959.
async with self._get_client() as client:
try:
response = await client.chat.completions.create(**request)
except openai.BadRequestError as e:
@@ -576,7 +628,7 @@ async def structured_output(
# We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying httpx
# client. The asyncio event loop does not allow connections to be shared. For more details, please refer to
# https://github.com/encode/httpx/discussions/2959.
async with self._get_client() as client:
try:
response: ParsedChatCompletion = await client.beta.chat.completions.parse(
model=self.get_config()["model_id"],
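The openai.py changes mirror this, with _get_client wrapping the injected-or-fresh choice in an async context manager. A sketch of both paths, with illustrative values; a wrapper such as the GuardrailsAsyncOpenAI mentioned in the docstring would be injected the same way:

import openai

from strands.models.openai import OpenAIModel

# Injected client: _get_client() yields it without closing it, so the
# caller owns shutdown (e.g., `await custom_client.close()` at teardown).
custom_client = openai.AsyncOpenAI(api_key="YOUR_API_KEY", base_url="https://gateway.example/v1")
model = OpenAIModel(client=custom_client, model_id="gpt-4o", params={"max_tokens": 256})

# Legacy path: a fresh AsyncOpenAI client is created from client_args and
# closed on context exit for every request (see the httpx discussion above).
legacy_model = OpenAIModel(client_args={"api_key": "YOUR_API_KEY"}, model_id="gpt-4o")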
74 changes: 74 additions & 0 deletions tests/strands/models/test_gemini.py
@@ -720,3 +720,77 @@ async def test_stream_handles_non_json_error(gemini_client, model, messages, caplog):

assert "Gemini API returned non-JSON error" in caplog.text
assert f"error_message=<{error_message}>" in caplog.text


@pytest.mark.asyncio
async def test_stream_with_injected_client(model_id, agenerator, alist):
"""Test that stream works with an injected client and doesn't close it."""
# Create a mock injected client
mock_injected_client = unittest.mock.Mock()
mock_injected_client.aio = unittest.mock.AsyncMock()

mock_injected_client.aio.models.generate_content_stream.return_value = agenerator(
[
genai.types.GenerateContentResponse(
candidates=[
genai.types.Candidate(
content=genai.types.Content(
parts=[genai.types.Part(text="Hello")],
),
finish_reason="STOP",
),
],
usage_metadata=genai.types.GenerateContentResponseUsageMetadata(
prompt_token_count=1,
total_token_count=3,
),
),
]
)

# Create model with injected client
model = GeminiModel(client=mock_injected_client, model_id=model_id)

messages = [{"role": "user", "content": [{"text": "test"}]}]
response = model.stream(messages)
tru_events = await alist(response)

# Verify events were generated
assert len(tru_events) > 0

# Verify the injected client was used
mock_injected_client.aio.models.generate_content_stream.assert_called_once()


@pytest.mark.asyncio
async def test_structured_output_with_injected_client(model_id, weather_output, alist):
"""Test that structured_output works with an injected client and doesn't close it."""
# Create a mock injected client
mock_injected_client = unittest.mock.Mock()
mock_injected_client.aio = unittest.mock.AsyncMock()

mock_injected_client.aio.models.generate_content.return_value = unittest.mock.Mock(
parsed=weather_output.model_dump()
)

# Create model with injected client
model = GeminiModel(client=mock_injected_client, model_id=model_id)

messages = [{"role": "user", "content": [{"text": "Generate weather"}]}]
stream = model.structured_output(type(weather_output), messages)
events = await alist(stream)

# Verify output was generated
assert len(events) == 1
assert events[0] == {"output": weather_output}

# Verify the injected client was used
mock_injected_client.aio.models.generate_content.assert_called_once()


def test_init_with_both_client_and_client_args_raises_error():
"""Test that providing both client and client_args raises ValueError."""
mock_client = unittest.mock.Mock()

with pytest.raises(ValueError, match="Only one of 'client' or 'client_args' should be provided"):
GeminiModel(client=mock_client, client_args={"api_key": "test"}, model_id="test-model")
79 changes: 78 additions & 1 deletion tests/strands/models/test_openai.py
@@ -13,7 +13,10 @@
def openai_client():
with unittest.mock.patch.object(strands.models.openai.openai, "AsyncOpenAI") as mock_client_cls:
mock_client = unittest.mock.AsyncMock()
# Make the mock client work as an async context manager
mock_client.__aenter__ = unittest.mock.AsyncMock(return_value=mock_client)
mock_client.__aexit__ = unittest.mock.AsyncMock(return_value=None)
mock_client_cls.return_value = mock_client
yield mock_client


@@ -986,3 +989,77 @@ def test_format_request_messages_drops_cache_points():
]

assert result == expected


@pytest.mark.asyncio
async def test_stream_with_injected_client(model_id, agenerator, alist):
"""Test that stream works with an injected client and doesn't close it."""
# Create a mock injected client
mock_injected_client = unittest.mock.AsyncMock()
mock_injected_client.close = unittest.mock.AsyncMock()

mock_delta = unittest.mock.Mock(content="Hello", tool_calls=None, reasoning_content=None)
mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)])
mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)])
mock_event_3 = unittest.mock.Mock()

mock_injected_client.chat.completions.create = unittest.mock.AsyncMock(
return_value=agenerator([mock_event_1, mock_event_2, mock_event_3])
)

# Create model with injected client
model = OpenAIModel(client=mock_injected_client, model_id=model_id, params={"max_tokens": 1})

messages = [{"role": "user", "content": [{"text": "test"}]}]
response = model.stream(messages)
tru_events = await alist(response)

# Verify events were generated
assert len(tru_events) > 0

# Verify the injected client was used
mock_injected_client.chat.completions.create.assert_called_once()

# Verify the injected client was NOT closed
mock_injected_client.close.assert_not_called()


@pytest.mark.asyncio
async def test_structured_output_with_injected_client(model_id, test_output_model_cls, alist):
"""Test that structured_output works with an injected client and doesn't close it."""
# Create a mock injected client
mock_injected_client = unittest.mock.AsyncMock()
mock_injected_client.close = unittest.mock.AsyncMock()

mock_parsed_instance = test_output_model_cls(name="John", age=30)
mock_choice = unittest.mock.Mock()
mock_choice.message.parsed = mock_parsed_instance
mock_response = unittest.mock.Mock()
mock_response.choices = [mock_choice]

mock_injected_client.beta.chat.completions.parse = unittest.mock.AsyncMock(return_value=mock_response)

# Create model with injected client
model = OpenAIModel(client=mock_injected_client, model_id=model_id, params={"max_tokens": 1})

messages = [{"role": "user", "content": [{"text": "Generate a person"}]}]
stream = model.structured_output(test_output_model_cls, messages)
events = await alist(stream)

# Verify output was generated
assert len(events) == 1
assert events[0] == {"output": test_output_model_cls(name="John", age=30)}

# Verify the injected client was used
mock_injected_client.beta.chat.completions.parse.assert_called_once()

# Verify the injected client was NOT closed
mock_injected_client.close.assert_not_called()


def test_init_with_both_client_and_client_args_raises_error():
"""Test that providing both client and client_args raises ValueError."""
mock_client = unittest.mock.AsyncMock()

with pytest.raises(ValueError, match="Only one of 'client' or 'client_args' should be provided"):
OpenAIModel(client=mock_client, client_args={"api_key": "test"}, model_id="test-model")