diff --git a/pyproject.toml b/pyproject.toml index 59e24dac3..0540f2e73 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ anthropic = ["anthropic>=0.21.0,<1.0.0"] gemini = ["google-genai>=1.32.0,<2.0.0"] litellm = ["litellm>=1.75.9,<2.0.0", "openai>=1.68.0,<3.0.0"] llamaapi = ["llama-api-client>=0.1.0,<1.0.0"] +minimax = ["openai>=1.68.0,<3.0.0"] mistral = ["mistralai>=1.8.2,<2.0.0"] ollama = ["ollama>=0.4.8,<1.0.0"] openai = ["openai>=1.68.0,<3.0.0"] @@ -83,7 +84,7 @@ bidi-io = [ bidi-gemini = ["google-genai>=1.32.0,<2.0.0"] bidi-openai = ["websockets>=15.0.0,<17.0.0"] -all = ["strands-agents[a2a,anthropic,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] +all = ["strands-agents[a2a,anthropic,docs,gemini,litellm,llamaapi,minimax,mistral,ollama,openai,writer,sagemaker,otel]"] bidi-all = ["strands-agents[a2a,bidi,bidi-io,bidi-gemini,bidi-openai,docs,otel]"] dev = [ diff --git a/src/strands/models/__init__.py b/src/strands/models/__init__.py index 2c582d116..48b893bcb 100644 --- a/src/strands/models/__init__.py +++ b/src/strands/models/__init__.py @@ -43,6 +43,10 @@ def __getattr__(name: str) -> Any: from .llamacpp import LlamaCppModel return LlamaCppModel + if name == "MinimaxModel": + from .minimax import MinimaxModel + + return MinimaxModel if name == "MistralModel": from .mistral import MistralModel diff --git a/src/strands/models/minimax.py b/src/strands/models/minimax.py new file mode 100644 index 000000000..0c19bc2ec --- /dev/null +++ b/src/strands/models/minimax.py @@ -0,0 +1,264 @@ +"""MiniMax model provider. 
+ +- Docs: https://platform.minimaxi.com/document/introduction +""" + +import json +import logging +import os +import re +from collections.abc import AsyncGenerator +from typing import Any, TypedDict, TypeVar + +import openai +from pydantic import BaseModel +from typing_extensions import Unpack, override + +from ..types.content import Messages +from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException +from ..types.streaming import StreamEvent +from ..types.tools import ToolChoice, ToolSpec +from ._validation import validate_config_keys +from .openai import OpenAIModel + +logger = logging.getLogger(__name__) + +T = TypeVar("T", bound=BaseModel) + +# Default MiniMax API base URL +_DEFAULT_MINIMAX_BASE_URL = "https://api.minimax.io/v1" + + +class MinimaxModel(OpenAIModel): + """MiniMax model provider implementation. + + This provider extends OpenAIModel to work with MiniMax's OpenAI-compatible API. + + MiniMax provides large language models including MiniMax-M2.7 and MiniMax-M2.5-highspeed, + accessible through an OpenAI-compatible chat completions endpoint. + + Example usage:: + + from strands import Agent + from strands.models.minimax import MinimaxModel + + model = MinimaxModel(model_id="MiniMax-M2.7") + agent = Agent(model=model) + response = agent("Tell me about AI") + + Attributes: + client: The underlying OpenAI-compatible async client for MiniMax API. + """ + + class MinimaxConfig(TypedDict, total=False): + """Configuration options for MiniMax models. + + Attributes: + model_id: Model ID (e.g., "MiniMax-M2.7", "MiniMax-M2.5-highspeed"). + For a complete list of supported models, see https://platform.minimaxi.com/document/models. + params: Model parameters (e.g., max_tokens, temperature). + For a complete list of supported parameters, see + https://platform.minimaxi.com/document/chat-completion-v2. 
+ """ + + model_id: str + params: dict[str, Any] | None + + def __init__( + self, + client: "OpenAIModel.Client | None" = None, + client_args: dict[str, Any] | None = None, + **model_config: Unpack[MinimaxConfig], + ) -> None: + """Initialize provider instance. + + If no client or client_args are provided, the provider will automatically configure + the MiniMax API base URL and read the API key from the ``MINIMAX_API_KEY`` environment + variable. + + Args: + client: Pre-configured OpenAI-compatible client to reuse across requests. + When provided, this client will be reused for all requests and will NOT be closed + by the model. The caller is responsible for managing the client lifecycle. + client_args: Arguments for the OpenAI client. + Defaults to using the MiniMax API base URL and API key from environment. + **model_config: Configuration options for the MiniMax model. + + Raises: + ValueError: If both ``client`` and ``client_args`` are provided. + """ + # Set default client_args for MiniMax if no client is provided + if client is None and client_args is None: + client_args = {} + + if client_args is not None: + client_args.setdefault("base_url", _DEFAULT_MINIMAX_BASE_URL) + client_args.setdefault("api_key", os.environ.get("MINIMAX_API_KEY", "")) + + super().__init__(client=client, client_args=client_args, **model_config) + + @override + def update_config(self, **model_config: Unpack[MinimaxConfig]) -> None: # type: ignore[override] + """Update the MiniMax model configuration with the provided arguments. + + Args: + **model_config: Configuration overrides. + """ + validate_config_keys(model_config, self.MinimaxConfig) + self.config.update(model_config) + + @override + def get_config(self) -> MinimaxConfig: + """Get the MiniMax model configuration. + + Returns: + The MiniMax model configuration. 
+ """ + from typing import cast + + return cast(MinimaxModel.MinimaxConfig, self.config) + + @override + def format_request( + self, + messages: Messages, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, + tool_choice: ToolChoice | None = None, + **kwargs: Any, + ) -> dict[str, Any]: + """Format a MiniMax-compatible chat streaming request. + + Extends the OpenAI format_request to remove empty tool lists, which + are not accepted by the MiniMax API. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. + **kwargs: Additional keyword arguments for future extensibility. + + Returns: + A MiniMax-compatible chat streaming request. + """ + request = super().format_request(messages, tool_specs, system_prompt, tool_choice, **kwargs) + + # MiniMax does not accept empty tools list + if not request.get("tools"): + request.pop("tools", None) + + return request + + @override + async def stream( + self, + messages: Messages, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, + *, + tool_choice: ToolChoice | None = None, + **kwargs: Any, + ) -> AsyncGenerator[StreamEvent, None]: + """Stream conversation with the MiniMax model. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt: System prompt to provide context to the model. + tool_choice: Selection strategy for tool invocation. + **kwargs: Additional keyword arguments for future extensibility. + + Yields: + Formatted message chunks from the model. + + Raises: + ContextWindowOverflowException: If the input exceeds the model's context window. + ModelThrottledException: If the request is throttled by MiniMax (rate limits). 
+ """ + async for event in super().stream( + messages, tool_specs, system_prompt, tool_choice=tool_choice, **kwargs + ): + yield event + + @staticmethod + def _clean_response_content(content: str) -> str: + """Clean MiniMax model output for structured parsing. + + MiniMax models may include: + - Reasoning content wrapped in ```` tags + - JSON wrapped in markdown code blocks (````json ... ````) + + This method strips those wrappers and returns only the meaningful content. + + Args: + content: Raw model response that may contain think tags or code blocks. + + Returns: + Cleaned content ready for JSON parsing. + """ + # Strip ... tags + content = re.sub(r".*?", "", content, flags=re.DOTALL).strip() + + # Strip markdown code blocks (```json ... ``` or ``` ... ```) + content = re.sub(r"^```(?:json)?\s*\n?", "", content) + content = re.sub(r"\n?```\s*$", "", content) + + return content.strip() + + @override + async def structured_output( + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: + """Get structured output from the MiniMax model. + + Uses a regular chat completion with ``response_format`` set to ``json_object`` + instead of the beta parse API, since MiniMax models may include ```` tags + that interfere with the beta parser. + + Args: + output_model: The output model to use for the agent. + prompt: The prompt messages to use for the agent. + system_prompt: System prompt to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. + + Yields: + Model events with the last being the structured output. + + Raises: + ContextWindowOverflowException: If the input exceeds the model's context window. + ModelThrottledException: If the request is throttled by MiniMax (rate limits). 
+ """ + request = self.format_request(prompt, system_prompt=system_prompt) + request["stream"] = False + request["response_format"] = {"type": "json_object"} + + # Remove stream_options for non-streaming request + request.pop("stream_options", None) + + # Add schema hint as a user message so the model knows the expected format + # (MiniMax only allows system messages at the beginning of the conversation) + schema_hint = { + "role": "user", + "content": f"Respond with a JSON object matching this schema: {json.dumps(output_model.model_json_schema())}", + } + request["messages"].append(schema_hint) + + async with self._get_client() as client: + try: + response = await client.chat.completions.create(**request) + except openai.BadRequestError as e: + if hasattr(e, "code") and e.code == "context_length_exceeded": + raise ContextWindowOverflowException(str(e)) from e + raise + except openai.RateLimitError as e: + raise ModelThrottledException(str(e)) from e + + content = response.choices[0].message.content or "" + content = self._clean_response_content(content) + + try: + parsed = output_model.model_validate_json(content) + yield {"output": parsed} + except Exception as e: + raise ValueError(f"Failed to parse MiniMax response into {output_model.__name__}: {e}") from e diff --git a/tests/strands/models/test_minimax.py b/tests/strands/models/test_minimax.py new file mode 100644 index 000000000..1be7fc0cb --- /dev/null +++ b/tests/strands/models/test_minimax.py @@ -0,0 +1,458 @@ +import unittest.mock + +import openai +import pydantic +import pytest + +import strands +from strands.models.minimax import MinimaxModel +from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException + + +@pytest.fixture +def openai_client(): + with unittest.mock.patch.object(strands.models.openai.openai, "AsyncOpenAI") as mock_client_cls: + mock_client = unittest.mock.AsyncMock() + # Make the mock client work as an async context manager + mock_client.__aenter__ = 
unittest.mock.AsyncMock(return_value=mock_client) + mock_client.__aexit__ = unittest.mock.AsyncMock(return_value=None) + mock_client_cls.return_value = mock_client + yield mock_client + + +@pytest.fixture +def model_id(): + return "MiniMax-M2.7" + + +@pytest.fixture +def model(openai_client, model_id): + _ = openai_client + + return MinimaxModel(model_id=model_id, params={"max_tokens": 1}) + + +@pytest.fixture +def messages(): + return [{"role": "user", "content": [{"text": "test"}]}] + + +@pytest.fixture +def tool_specs(): + return [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "input": {"type": "string"}, + }, + "required": ["input"], + }, + }, + }, + ] + + +@pytest.fixture +def system_prompt(): + return "s1" + + +@pytest.fixture +def test_output_model_cls(): + class TestOutputModel(pydantic.BaseModel): + name: str + age: int + + return TestOutputModel + + +def test__init__(openai_client, model_id): + _ = openai_client + + model = MinimaxModel(model_id=model_id, params={"max_tokens": 1}) + + tru_config = model.get_config() + exp_config = {"model_id": "MiniMax-M2.7", "params": {"max_tokens": 1}} + + assert tru_config == exp_config + + +def test__init__default_client_args(openai_client): + """Test that default client_args include MiniMax base_url and API key.""" + _ = openai_client + + with unittest.mock.patch.dict("os.environ", {"MINIMAX_API_KEY": "test-key"}): + model = MinimaxModel(model_id="MiniMax-M2.7") + + assert model.client_args["base_url"] == "https://api.minimax.io/v1" + assert model.client_args["api_key"] == "test-key" + + +def test__init__custom_client_args(openai_client): + """Test that custom client_args override defaults.""" + _ = openai_client + + model = MinimaxModel( + client_args={"base_url": "https://custom.api.com/v1", "api_key": "custom-key"}, + model_id="MiniMax-M2.7", + ) + + assert model.client_args["base_url"] == "https://custom.api.com/v1" + assert 
model.client_args["api_key"] == "custom-key" + + +def test__init__with_injected_client(openai_client): + """Test initialization with a pre-configured client.""" + mock_client = unittest.mock.AsyncMock() + + model = MinimaxModel(client=mock_client, model_id="MiniMax-M2.7") + + assert model._custom_client is mock_client + + +def test__init__with_both_client_and_client_args_raises_error(): + """Test that providing both client and client_args raises ValueError.""" + mock_client = unittest.mock.AsyncMock() + + with pytest.raises(ValueError, match="Only one of 'client' or 'client_args' should be provided"): + MinimaxModel(client=mock_client, client_args={"api_key": "test"}, model_id="MiniMax-M2.7") + + +def test_update_config(model, model_id): + model.update_config(model_id=model_id) + + tru_model_id = model.get_config().get("model_id") + exp_model_id = model_id + + assert tru_model_id == exp_model_id + + +def test_format_request_includes_stream_options(model, messages, tool_specs, system_prompt): + """Test that stream_options is included in the request for usage tracking.""" + tru_request = model.format_request(messages, tool_specs, system_prompt) + + assert tru_request["stream_options"] == {"include_usage": True} + assert tru_request["model"] == "MiniMax-M2.7" + assert tru_request["stream"] is True + + +def test_format_request_no_empty_tools(model, messages, system_prompt): + """Test that empty tools list is removed from the request.""" + tru_request = model.format_request(messages, None, system_prompt) + + assert "tools" not in tru_request + + +def test_format_request_with_tools(model, messages, tool_specs, system_prompt): + """Test that non-empty tools list is preserved.""" + tru_request = model.format_request(messages, tool_specs, system_prompt) + + assert "tools" in tru_request + assert len(tru_request["tools"]) == 1 + assert tru_request["tools"][0]["function"]["name"] == "test_tool" + + +def test_format_request(model, messages, tool_specs, system_prompt): + 
tru_request = model.format_request(messages, tool_specs, system_prompt) + exp_request = { + "messages": [ + { + "content": system_prompt, + "role": "system", + }, + { + "content": [{"text": "test", "type": "text"}], + "role": "user", + }, + ], + "model": "MiniMax-M2.7", + "stream": True, + "stream_options": {"include_usage": True}, + "tools": [ + { + "function": { + "description": "A test tool", + "name": "test_tool", + "parameters": { + "properties": { + "input": {"type": "string"}, + }, + "required": ["input"], + "type": "object", + }, + }, + "type": "function", + }, + ], + "max_tokens": 1, + } + assert tru_request == exp_request + + +@pytest.mark.asyncio +async def test_stream(openai_client, model_id, model, agenerator, alist): + mock_delta = unittest.mock.Mock(content="Hello", tool_calls=None, reasoning_content=None) + + mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) + mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) + mock_event_3 = unittest.mock.Mock() + + openai_client.chat.completions.create = unittest.mock.AsyncMock( + return_value=agenerator([mock_event_1, mock_event_2, mock_event_3]) + ) + + messages = [{"role": "user", "content": [{"text": "test"}]}] + response = model.stream(messages) + tru_events = await alist(response) + + # Should have: messageStart, contentBlockStart, 2x contentBlockDelta, contentBlockStop, messageStop + assert len(tru_events) >= 4 + assert tru_events[0] == {"messageStart": {"role": "assistant"}} + + # Verify the request was made with stream_options for usage tracking + call_kwargs = openai_client.chat.completions.create.call_args[1] + assert call_kwargs["stream_options"] == {"include_usage": True} + + +@pytest.mark.asyncio +async def test_stream_context_overflow_exception(openai_client, model, messages): + """Test that context overflow errors are properly converted.""" + mock_error = openai.BadRequestError( + message="This 
model's maximum context length exceeded.", + response=unittest.mock.MagicMock(), + body={"error": {"code": "context_length_exceeded"}}, + ) + mock_error.code = "context_length_exceeded" + + openai_client.chat.completions.create.side_effect = mock_error + + with pytest.raises(ContextWindowOverflowException): + async for _ in model.stream(messages): + pass + + +@pytest.mark.asyncio +async def test_stream_rate_limit_as_throttle(openai_client, model, messages): + """Test that rate limit errors are converted to ModelThrottledException.""" + mock_error = openai.RateLimitError( + message="Rate limit reached.", + response=unittest.mock.MagicMock(), + body={"error": {"code": "rate_limit_exceeded"}}, + ) + mock_error.code = "rate_limit_exceeded" + + openai_client.chat.completions.create.side_effect = mock_error + + with pytest.raises(ModelThrottledException): + async for _ in model.stream(messages): + pass + + +@pytest.mark.asyncio +async def test_stream_with_tool_calls(openai_client, model, agenerator, alist): + mock_tool_call = unittest.mock.Mock(index=0) + mock_delta_1 = unittest.mock.Mock(content="I'll help", tool_calls=[mock_tool_call], reasoning_content=None) + mock_delta_2 = unittest.mock.Mock(content="", tool_calls=None, reasoning_content=None) + + mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_1)]) + mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_2)]) + mock_event_3 = unittest.mock.Mock() + + openai_client.chat.completions.create = unittest.mock.AsyncMock( + return_value=agenerator([mock_event_1, mock_event_2, mock_event_3]) + ) + + messages = [{"role": "user", "content": [{"text": "use the tool"}]}] + response = model.stream(messages) + tru_events = await alist(response) + + # Should contain tool_use stop reason + stop_events = [e for e in tru_events if "messageStop" in e] + assert len(stop_events) == 1 + assert stop_events[0]["messageStop"]["stopReason"] 
== "tool_use" + + +@pytest.mark.asyncio +async def test_structured_output(openai_client, model, test_output_model_cls, alist): + """Test structured output using regular completion with response_format.""" + messages = [{"role": "user", "content": [{"text": "Generate a person"}]}] + + mock_choice = unittest.mock.Mock() + mock_choice.message.content = '{"name": "John", "age": 30}' + mock_response = unittest.mock.Mock() + mock_response.choices = [mock_choice] + + openai_client.chat.completions.create = unittest.mock.AsyncMock(return_value=mock_response) + + stream = model.structured_output(test_output_model_cls, messages) + events = await alist(stream) + + tru_result = events[-1] + exp_result = {"output": test_output_model_cls(name="John", age=30)} + assert tru_result == exp_result + + +@pytest.mark.asyncio +async def test_structured_output_with_think_tags(openai_client, model, test_output_model_cls, alist): + """Test structured output strips tags from MiniMax response.""" + messages = [{"role": "user", "content": [{"text": "Generate a person"}]}] + + mock_choice = unittest.mock.Mock() + mock_choice.message.content = '\nLet me think about this.\n\n{"name": "Alice", "age": 25}' + mock_response = unittest.mock.Mock() + mock_response.choices = [mock_choice] + + openai_client.chat.completions.create = unittest.mock.AsyncMock(return_value=mock_response) + + stream = model.structured_output(test_output_model_cls, messages) + events = await alist(stream) + + tru_result = events[-1] + exp_result = {"output": test_output_model_cls(name="Alice", age=25)} + assert tru_result == exp_result + + +def test_config_validation_warns_on_unknown_keys(openai_client, captured_warnings): + """Test that unknown config keys emit a warning.""" + MinimaxModel(model_id="MiniMax-M2.7", invalid_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "invalid_param" in str(captured_warnings[0].message) + + +def 
test_update_config_validation_warns_on_unknown_keys(model, captured_warnings): + """Test that update_config warns on unknown keys.""" + model.update_config(wrong_param="test") + + assert len(captured_warnings) == 1 + assert "Invalid configuration parameters" in str(captured_warnings[0].message) + assert "wrong_param" in str(captured_warnings[0].message) + + +def test_format_request_messages(system_prompt): + """Test message formatting inherits from OpenAI properly.""" + messages = [ + { + "content": [{"text": "hello"}], + "role": "user", + }, + ] + + tru_result = MinimaxModel.format_request_messages(messages, system_prompt) + exp_result = [ + { + "content": system_prompt, + "role": "system", + }, + { + "content": [{"text": "hello", "type": "text"}], + "role": "user", + }, + ] + assert tru_result == exp_result + + +def test_format_request_message_content(): + """Test content formatting inherits from OpenAI properly.""" + content = {"text": "hello"} + tru_result = MinimaxModel.format_request_message_content(content) + exp_result = {"type": "text", "text": "hello"} + assert tru_result == exp_result + + +def test_format_request_message_tool_call(): + """Test tool call formatting inherits from OpenAI properly.""" + tool_use = { + "input": {"expression": "2+2"}, + "name": "calculator", + "toolUseId": "c1", + } + + tru_result = MinimaxModel.format_request_message_tool_call(tool_use) + exp_result = { + "function": { + "arguments": '{"expression": "2+2"}', + "name": "calculator", + }, + "id": "c1", + "type": "function", + } + assert tru_result == exp_result + + +@pytest.mark.parametrize( + ("event", "exp_chunk"), + [ + # Message start + ( + {"chunk_type": "message_start"}, + {"messageStart": {"role": "assistant"}}, + ), + # Content Delta - Text + ( + {"chunk_type": "content_delta", "data_type": "text", "data": "hello"}, + {"contentBlockDelta": {"delta": {"text": "hello"}}}, + ), + # Content Stop + ( + {"chunk_type": "content_stop"}, + {"contentBlockStop": {}}, + ), + # 
Message Stop - End Turn + ( + {"chunk_type": "message_stop", "data": "stop"}, + {"messageStop": {"stopReason": "end_turn"}}, + ), + ], +) +def test_format_chunk(event, exp_chunk, model): + tru_chunk = model.format_chunk(event) + assert tru_chunk == exp_chunk + + +@pytest.mark.parametrize( + "input_text, expected", + [ + ("<think>\nSome reasoning.\n</think>\n\n{}", "{}"), + ("<think>quick thought</think>Hello", "Hello"), + ("No think tags here", "No think tags here"), + ('<think>\nMultiple\nlines\n</think>\n\n{"key": "value"}', '{"key": "value"}'), + ('```json\n{"key": "value"}\n```', '{"key": "value"}'), + ('```\n{"key": "value"}\n```', '{"key": "value"}'), + ('<think>\nthinking\n</think>\n\n```json\n{"key": "value"}\n```', '{"key": "value"}'), + ], +) +def test_clean_response_content(input_text, expected): + """Test that _clean_response_content removes think tags and code blocks.""" + assert MinimaxModel._clean_response_content(input_text) == expected + + +@pytest.mark.asyncio +async def test_stream_with_injected_client(model_id, agenerator, alist): + """Test that stream works with an injected client.""" + mock_injected_client = unittest.mock.AsyncMock() + mock_injected_client.close = unittest.mock.AsyncMock() + + mock_delta = unittest.mock.Mock(content="Hello", tool_calls=None, reasoning_content=None) + mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) + mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) + mock_event_3 = unittest.mock.Mock() + + mock_injected_client.chat.completions.create = unittest.mock.AsyncMock( + return_value=agenerator([mock_event_1, mock_event_2, mock_event_3]) + ) + + model = MinimaxModel(client=mock_injected_client, model_id=model_id, params={"max_tokens": 1}) + + messages = [{"role": "user", "content": [{"text": "test"}]}] + response = model.stream(messages) + tru_events = await alist(response) + + assert len(tru_events) > 0 + mock_injected_client.chat.completions.create.assert_called_once() + 
mock_injected_client.close.assert_not_called() diff --git a/tests_integ/models/providers.py b/tests_integ/models/providers.py index 15161b9cb..e6894f767 100644 --- a/tests_integ/models/providers.py +++ b/tests_integ/models/providers.py @@ -13,6 +13,7 @@ from strands.models.gemini import GeminiModel from strands.models.litellm import LiteLLMModel from strands.models.llamaapi import LlamaAPIModel +from strands.models.minimax import MinimaxModel from strands.models.mistral import MistralModel from strands.models.ollama import OllamaModel from strands.models.openai import OpenAIModel @@ -103,6 +104,16 @@ def __init__(self): }, ), ) +minimax = ProviderInfo( + id="minimax", + environment_variable="MINIMAX_API_KEY", + factory=lambda: MinimaxModel( + model_id="MiniMax-M2.7", + client_args={ + "api_key": os.getenv("MINIMAX_API_KEY"), + }, + ), +) mistral = ProviderInfo( id="mistral", environment_variable="MISTRAL_API_KEY", @@ -169,6 +180,7 @@ def __init__(self): gemini, llama, litellm, + minimax, mistral, openai, openai_responses, diff --git a/tests_integ/models/test_model_minimax.py b/tests_integ/models/test_model_minimax.py new file mode 100644 index 000000000..ff1fa81e2 --- /dev/null +++ b/tests_integ/models/test_model_minimax.py @@ -0,0 +1,74 @@ +import os + +import pydantic +import pytest + +import strands +from strands import Agent +from strands.models.minimax import MinimaxModel +from tests_integ.models import providers + +# these tests only run if we have the minimax api key +pytestmark = providers.minimax.mark + + +@pytest.fixture +def model(): + return MinimaxModel( + model_id="MiniMax-M2.7", + client_args={ + "api_key": os.getenv("MINIMAX_API_KEY"), + }, + ) + + +@pytest.fixture +def tools(): + @strands.tool + def tool_time() -> str: + return "12:00" + + @strands.tool + def tool_weather() -> str: + return "sunny" + + return [tool_time, tool_weather] + + +@pytest.fixture +def agent(model, tools): + return Agent(model=model, tools=tools) + + +@pytest.fixture +def 
weather(): + class Weather(pydantic.BaseModel): + """Extract time and weather values.""" + + time: str = pydantic.Field(description="The time value only, e.g. '14:30' not 'The time is 14:30'") + weather: str = pydantic.Field( + description="The weather condition only, e.g. 'rainy' not 'the weather is rainy'" + ) + + return Weather(time="12:00", weather="sunny") + + +def test_agent_invoke(agent, model): + result = agent("What is the time and weather in New York?") + text = result.message["content"][0]["text"].lower() + + assert all(string in text for string in ["12:00", "sunny"]) + + +@pytest.mark.asyncio +async def test_agent_invoke_async(agent, model): + result = await agent.invoke_async("What is the time and weather in New York?") + text = result.message["content"][0]["text"].lower() + + assert all(string in text for string in ["12:00", "sunny"]) + + +def test_agent_structured_output(model, weather): + agent = Agent(model=model) + result = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny") + assert result == weather