diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py
index 22feecf32..cf7cc604a 100644
--- a/src/strands/models/gemini.py
+++ b/src/strands/models/gemini.py
@@ -54,27 +54,44 @@ class GeminiConfig(TypedDict, total=False):
     def __init__(
         self,
         *,
+        client: Optional[genai.Client] = None,
         client_args: Optional[dict[str, Any]] = None,
         **model_config: Unpack[GeminiConfig],
     ) -> None:
         """Initialize provider instance.
 
         Args:
+            client: Pre-configured Gemini client to reuse across requests.
+                When provided, this client will be reused for all requests and will NOT be closed
+                by the model. The caller is responsible for managing the client lifecycle.
+                This is useful for:
+                - Injecting custom client wrappers
+                - Reusing connection pools within a single event loop/worker
+                - Centralizing observability, retries, and networking policy
+                Note: The client should not be shared across different asyncio event loops.
             client_args: Arguments for the underlying Gemini client (e.g., api_key).
                 For a complete list of supported arguments, see https://googleapis.github.io/python-genai/.
             **model_config: Configuration options for the Gemini model.
+
+        Raises:
+            ValueError: If both `client` and `client_args` are provided.
         """
         validate_config_keys(model_config, GeminiModel.GeminiConfig)
         self.config = GeminiModel.GeminiConfig(**model_config)
 
+        # Validate that only one client configuration method is provided
+        if client is not None and client_args is not None and len(client_args) > 0:
+            raise ValueError("Only one of 'client' or 'client_args' should be provided, not both.")
+
+        self._custom_client = client
+        self.client_args = client_args or {}
+
         # Validate gemini_tools if provided
         if "gemini_tools" in self.config:
             self._validate_gemini_tools(self.config["gemini_tools"])
 
         logger.debug("config=<%s> | initializing", self.config)
 
-        self.client_args = client_args or {}
-
     @override
     def update_config(self, **model_config: Unpack[GeminiConfig]) -> None:  # type: ignore[override]
         """Update the Gemini model configuration with the provided arguments.
@@ -97,6 +114,24 @@ def get_config(self) -> GeminiConfig:
         """
         return self.config
 
+    def _get_client(self) -> genai.Client:
+        """Get a Gemini client for making requests.
+
+        This method handles client lifecycle management:
+        - If an injected client was provided during initialization, it returns that client
+          without managing its lifecycle (caller is responsible for cleanup).
+        - Otherwise, creates a new genai.Client from client_args.
+
+        Returns:
+            genai.Client: A Gemini client instance.
+        """
+        if self._custom_client is not None:
+            # Use the injected client (caller manages lifecycle)
+            return self._custom_client
+        else:
+            # Create a new client from client_args
+            return genai.Client(**self.client_args)
+
     def _format_request_content_part(self, content: ContentBlock) -> genai.types.Part:
         """Format content block into a Gemini part instance.
 
@@ -382,7 +417,8 @@ async def stream(
         """
         request = self._format_request(messages, tool_specs, system_prompt, self.config.get("params"))
 
-        client = genai.Client(**self.client_args).aio
+        client = self._get_client().aio
+
         try:
             response = await client.models.generate_content_stream(**request)
 
@@ -465,7 +501,7 @@ async def structured_output(
             "response_schema": output_model.model_json_schema(),
         }
         request = self._format_request(prompt, None, system_prompt, params)
-        client = genai.Client(**self.client_args).aio
+        client = self._get_client().aio
 
         response = await client.models.generate_content(**request)
         yield {"output": output_model.model_validate(response.parsed)}
diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py
index 435c82cab..07246c5d6 100644
--- a/src/strands/models/openai.py
+++ b/src/strands/models/openai.py
@@ -7,7 +7,8 @@
 import json
 import logging
 import mimetypes
-from typing import Any, AsyncGenerator, Optional, Protocol, Type, TypedDict, TypeVar, Union, cast
+from contextlib import asynccontextmanager
+from typing import Any, AsyncGenerator, AsyncIterator, Optional, Protocol, Type, TypedDict, TypeVar, Union, cast
 
 import openai
 from openai.types.chat.parsed_chat_completion import ParsedChatCompletion
@@ -55,16 +56,39 @@ class OpenAIConfig(TypedDict, total=False):
         model_id: str
         params: Optional[dict[str, Any]]
 
-    def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[OpenAIConfig]) -> None:
+    def __init__(
+        self,
+        client: Optional[Client] = None,
+        client_args: Optional[dict[str, Any]] = None,
+        **model_config: Unpack[OpenAIConfig],
+    ) -> None:
         """Initialize provider instance.
 
         Args:
-            client_args: Arguments for the OpenAI client.
+            client: Pre-configured OpenAI-compatible client to reuse across requests.
+                When provided, this client will be reused for all requests and will NOT be closed
+                by the model. The caller is responsible for managing the client lifecycle.
+                This is useful for:
+                - Injecting custom client wrappers (e.g., GuardrailsAsyncOpenAI)
+                - Reusing connection pools within a single event loop/worker
+                - Centralizing observability, retries, and networking policy
+                - Pointing to custom model gateways
+                Note: The client should not be shared across different asyncio event loops.
+            client_args: Arguments for the OpenAI client (legacy approach).
                 For a complete list of supported arguments, see https://pypi.org/project/openai/.
             **model_config: Configuration options for the OpenAI model.
+
+        Raises:
+            ValueError: If both `client` and `client_args` are provided.
         """
         validate_config_keys(model_config, self.OpenAIConfig)
         self.config = dict(model_config)
+
+        # Validate that only one client configuration method is provided
+        if client is not None and client_args is not None and len(client_args) > 0:
+            raise ValueError("Only one of 'client' or 'client_args' should be provided, not both.")
+
+        self._custom_client = client
         self.client_args = client_args or {}
 
         logger.debug("config=<%s> | initializing", self.config)
@@ -422,6 +446,34 @@ def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent:
             case _:
                 raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type")
 
+    @asynccontextmanager
+    async def _get_client(self) -> AsyncIterator[Any]:
+        """Get an OpenAI client for making requests.
+
+        This context manager handles client lifecycle management:
+        - If an injected client was provided during initialization, it yields that client
+          without closing it (caller manages lifecycle).
+        - Otherwise, creates a new AsyncOpenAI client from client_args and automatically
+          closes it when the context exits.
+
+        Note: We create a new client per request to avoid connection sharing in the underlying
+        httpx client, as the asyncio event loop does not allow connections to be shared.
+        For more details, see https://github.com/encode/httpx/discussions/2959.
+
+        Yields:
+            Client: An OpenAI-compatible client instance.
+        """
+        if self._custom_client is not None:
+            # Use the injected client (caller manages lifecycle)
+            yield self._custom_client
+        else:
+            # Create a new client from client_args
+            # We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying
+            # httpx client. The asyncio event loop does not allow connections to be shared. For more details, please
+            # refer to https://github.com/encode/httpx/discussions/2959.
+            async with openai.AsyncOpenAI(**self.client_args) as client:
+                yield client
+
     @override
     async def stream(
         self,
@@ -457,7 +509,7 @@ async def stream(
         # We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying httpx
         # client. The asyncio event loop does not allow connections to be shared. For more details, please refer to
         # https://github.com/encode/httpx/discussions/2959.
-        async with openai.AsyncOpenAI(**self.client_args) as client:
+        async with self._get_client() as client:
             try:
                 response = await client.chat.completions.create(**request)
             except openai.BadRequestError as e:
@@ -576,7 +628,7 @@ async def structured_output(
         # We initialize an OpenAI context on every request so as to avoid connection sharing in the underlying httpx
         # client. The asyncio event loop does not allow connections to be shared. For more details, please refer to
         # https://github.com/encode/httpx/discussions/2959.
-        async with openai.AsyncOpenAI(**self.client_args) as client:
+        async with self._get_client() as client:
             try:
                 response: ParsedChatCompletion = await client.beta.chat.completions.parse(
                     model=self.get_config()["model_id"],
diff --git a/tests/strands/models/test_gemini.py b/tests/strands/models/test_gemini.py
index 8e8742f94..c552a892a 100644
--- a/tests/strands/models/test_gemini.py
+++ b/tests/strands/models/test_gemini.py
@@ -720,3 +720,77 @@ async def test_stream_handles_non_json_error(gemini_client, model, messages, cap
 
     assert "Gemini API returned non-JSON error" in caplog.text
     assert f"error_message=<{error_message}>" in caplog.text
+
+
+@pytest.mark.asyncio
+async def test_stream_with_injected_client(model_id, agenerator, alist):
+    """Test that stream works with an injected client and doesn't close it."""
+    # Create a mock injected client
+    mock_injected_client = unittest.mock.Mock()
+    mock_injected_client.aio = unittest.mock.AsyncMock()
+
+    mock_injected_client.aio.models.generate_content_stream.return_value = agenerator(
+        [
+            genai.types.GenerateContentResponse(
+                candidates=[
+                    genai.types.Candidate(
+                        content=genai.types.Content(
+                            parts=[genai.types.Part(text="Hello")],
+                        ),
+                        finish_reason="STOP",
+                    ),
+                ],
+                usage_metadata=genai.types.GenerateContentResponseUsageMetadata(
+                    prompt_token_count=1,
+                    total_token_count=3,
+                ),
+            ),
+        ]
+    )
+
+    # Create model with injected client
+    model = GeminiModel(client=mock_injected_client, model_id=model_id)
+
+    messages = [{"role": "user", "content": [{"text": "test"}]}]
+    response = model.stream(messages)
+    tru_events = await alist(response)
+
+    # Verify events were generated
+    assert len(tru_events) > 0
+
+    # Verify the injected client was used
+    mock_injected_client.aio.models.generate_content_stream.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_structured_output_with_injected_client(model_id, weather_output, alist):
+    """Test that structured_output works with an injected client and doesn't close it."""
+    # Create a mock injected client
+    mock_injected_client = unittest.mock.Mock()
+    mock_injected_client.aio = unittest.mock.AsyncMock()
+
+    mock_injected_client.aio.models.generate_content.return_value = unittest.mock.Mock(
+        parsed=weather_output.model_dump()
+    )
+
+    # Create model with injected client
+    model = GeminiModel(client=mock_injected_client, model_id=model_id)
+
+    messages = [{"role": "user", "content": [{"text": "Generate weather"}]}]
+    stream = model.structured_output(type(weather_output), messages)
+    events = await alist(stream)
+
+    # Verify output was generated
+    assert len(events) == 1
+    assert events[0] == {"output": weather_output}
+
+    # Verify the injected client was used
+    mock_injected_client.aio.models.generate_content.assert_called_once()
+
+
+def test_init_with_both_client_and_client_args_raises_error():
+    """Test that providing both client and client_args raises ValueError."""
+    mock_client = unittest.mock.Mock()
+
+    with pytest.raises(ValueError, match="Only one of 'client' or 'client_args' should be provided"):
+        GeminiModel(client=mock_client, client_args={"api_key": "test"}, model_id="test-model")
diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py
index 0de0c4ebc..ef173d349 100644
--- a/tests/strands/models/test_openai.py
+++ b/tests/strands/models/test_openai.py
@@ -13,7 +13,10 @@ def openai_client():
     with unittest.mock.patch.object(strands.models.openai.openai, "AsyncOpenAI") as mock_client_cls:
         mock_client = unittest.mock.AsyncMock()
-        mock_client_cls.return_value.__aenter__.return_value = mock_client
+        # Make the mock client work as an async context manager
+        mock_client.__aenter__ = unittest.mock.AsyncMock(return_value=mock_client)
+        mock_client.__aexit__ = unittest.mock.AsyncMock(return_value=None)
+        mock_client_cls.return_value = mock_client
 
         yield mock_client
@@ -986,3 +989,77 @@ def test_format_request_messages_drops_cache_points():
     ]
 
     assert result == expected
+
+
+@pytest.mark.asyncio
+async def test_stream_with_injected_client(model_id, agenerator, alist):
+    """Test that stream works with an injected client and doesn't close it."""
+    # Create a mock injected client
+    mock_injected_client = unittest.mock.AsyncMock()
+    mock_injected_client.close = unittest.mock.AsyncMock()
+
+    mock_delta = unittest.mock.Mock(content="Hello", tool_calls=None, reasoning_content=None)
+    mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)])
+    mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)])
+    mock_event_3 = unittest.mock.Mock()
+
+    mock_injected_client.chat.completions.create = unittest.mock.AsyncMock(
+        return_value=agenerator([mock_event_1, mock_event_2, mock_event_3])
+    )
+
+    # Create model with injected client
+    model = OpenAIModel(client=mock_injected_client, model_id=model_id, params={"max_tokens": 1})
+
+    messages = [{"role": "user", "content": [{"text": "test"}]}]
+    response = model.stream(messages)
+    tru_events = await alist(response)
+
+    # Verify events were generated
+    assert len(tru_events) > 0
+
+    # Verify the injected client was used
+    mock_injected_client.chat.completions.create.assert_called_once()
+
+    # Verify the injected client was NOT closed
+    mock_injected_client.close.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_structured_output_with_injected_client(model_id, test_output_model_cls, alist):
+    """Test that structured_output works with an injected client and doesn't close it."""
+    # Create a mock injected client
+    mock_injected_client = unittest.mock.AsyncMock()
+    mock_injected_client.close = unittest.mock.AsyncMock()
+
+    mock_parsed_instance = test_output_model_cls(name="John", age=30)
+    mock_choice = unittest.mock.Mock()
+    mock_choice.message.parsed = mock_parsed_instance
+    mock_response = unittest.mock.Mock()
+    mock_response.choices = [mock_choice]
+
+    mock_injected_client.beta.chat.completions.parse = unittest.mock.AsyncMock(return_value=mock_response)
+
+    # Create model with injected client
+    model = OpenAIModel(client=mock_injected_client, model_id=model_id, params={"max_tokens": 1})
+
+    messages = [{"role": "user", "content": [{"text": "Generate a person"}]}]
+    stream = model.structured_output(test_output_model_cls, messages)
+    events = await alist(stream)
+
+    # Verify output was generated
+    assert len(events) == 1
+    assert events[0] == {"output": test_output_model_cls(name="John", age=30)}
+
+    # Verify the injected client was used
+    mock_injected_client.beta.chat.completions.parse.assert_called_once()
+
+    # Verify the injected client was NOT closed
+    mock_injected_client.close.assert_not_called()
+
+
+def test_init_with_both_client_and_client_args_raises_error():
+    """Test that providing both client and client_args raises ValueError."""
+    mock_client = unittest.mock.AsyncMock()
+
+    with pytest.raises(ValueError, match="Only one of 'client' or 'client_args' should be provided"):
+        OpenAIModel(client=mock_client, client_args={"api_key": "test"}, model_id="test-model")
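
Usage sketch: a minimal example of the new `client` parameter on GeminiModel, assuming the constructor shown in the patch above. The api_key and model_id values are placeholders; the caller keeps ownership of the injected client, per the docstring in GeminiModel.__init__.

# Sketch only: inject a caller-owned Gemini client (values are placeholders).
from google import genai

from strands.models.gemini import GeminiModel

# The injected client is reused for every request and is never closed by the model,
# so create it once per event loop and manage its lifetime in application code.
shared_client = genai.Client(api_key="YOUR_API_KEY")

model = GeminiModel(client=shared_client, model_id="gemini-2.0-flash")

# Passing both configuration styles is rejected by the new validation:
# GeminiModel(client=shared_client, client_args={"api_key": "..."}, model_id="...")  # raises ValueError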
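
Similarly for OpenAIModel, a minimal sketch of the injected-client path, assuming an AsyncOpenAI-compatible client; the base_url, api_key, and model_id values are placeholders. When no client is injected, the model falls back to client_args and opens a fresh AsyncOpenAI client per request via the _get_client() context manager.

# Sketch only: reuse one AsyncOpenAI-compatible client across requests (values are placeholders).
import openai

from strands.models.openai import OpenAIModel

# Caller-owned client, e.g. pointed at an internal gateway with custom retry behavior.
gateway_client = openai.AsyncOpenAI(
    base_url="https://gateway.example.com/v1",
    api_key="YOUR_API_KEY",
    max_retries=5,
)

model = OpenAIModel(client=gateway_client, model_id="gpt-4o", params={"max_tokens": 256})

# The model never closes the injected client; shut it down in application code,
# and do not share it across asyncio event loops.
# await gateway_client.close()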