From a8a27557388ce67acdb8296386c6b3956033db76 Mon Sep 17 00:00:00 2001 From: philipph-askui Date: Fri, 24 Apr 2026 17:08:06 +0200 Subject: [PATCH 1/3] feat: add support for openai-compatible APIs as vlm_provider and image_qa_provider --- src/askui/model_providers/__init__.py | 12 + .../ollama_image_qa_provider.py | 49 ++ .../model_providers/ollama_vlm_provider.py | 49 ++ .../openai_image_qa_provider.py | 78 +++ .../model_providers/openai_vlm_provider.py | 121 +++++ src/askui/models/openai/__init__.py | 1 + src/askui/models/openai/get_model.py | 169 +++++++ src/askui/models/openai/messages_api.py | 332 +++++++++++++ src/askui/utils/model_pricing.py | 18 + .../test_ollama_image_qa_provider.py | 64 +++ .../test_ollama_vlm_provider.py | 50 ++ .../test_openai_image_qa_provider.py | 56 +++ .../test_openai_vlm_provider.py | 43 ++ tests/unit/models/openai/__init__.py | 0 tests/unit/models/openai/test_get_model.py | 96 ++++ tests/unit/models/openai/test_messages_api.py | 443 ++++++++++++++++++ 16 files changed, 1581 insertions(+) create mode 100644 src/askui/model_providers/ollama_image_qa_provider.py create mode 100644 src/askui/model_providers/ollama_vlm_provider.py create mode 100644 src/askui/model_providers/openai_image_qa_provider.py create mode 100644 src/askui/model_providers/openai_vlm_provider.py create mode 100644 src/askui/models/openai/__init__.py create mode 100644 src/askui/models/openai/get_model.py create mode 100644 src/askui/models/openai/messages_api.py create mode 100644 tests/unit/model_providers/test_ollama_image_qa_provider.py create mode 100644 tests/unit/model_providers/test_ollama_vlm_provider.py create mode 100644 tests/unit/model_providers/test_openai_image_qa_provider.py create mode 100644 tests/unit/model_providers/test_openai_vlm_provider.py create mode 100644 tests/unit/models/openai/__init__.py create mode 100644 tests/unit/models/openai/test_get_model.py create mode 100644 tests/unit/models/openai/test_messages_api.py diff --git a/src/askui/model_providers/__init__.py b/src/askui/model_providers/__init__.py index add59506..b804e473 100644 --- a/src/askui/model_providers/__init__.py +++ b/src/askui/model_providers/__init__.py @@ -12,6 +12,10 @@ - `AnthropicVlmProvider` — VLM via direct Anthropic API - `AnthropicImageQAProvider` — image Q&A via direct Anthropic API - `GoogleImageQAProvider` — image Q&A via Google Gemini API (direct, no proxy) +- `OpenAIVlmProvider` — VLM via any OpenAI-compatible API +- `OpenAIImageQAProvider` — image Q&A via any OpenAI-compatible API +- `OllamaVlmProvider` — VLM via local Ollama instance (OpenAI-compatible) +- `OllamaImageQAProvider` — image Q&A via local Ollama instance (OpenAI-compatible) """ from askui.model_providers.anthropic_image_qa_provider import AnthropicImageQAProvider @@ -22,6 +26,10 @@ from askui.model_providers.detection_provider import DetectionProvider from askui.model_providers.google_image_qa_provider import GoogleImageQAProvider from askui.model_providers.image_qa_provider import ImageQAProvider +from askui.model_providers.ollama_image_qa_provider import OllamaImageQAProvider +from askui.model_providers.ollama_vlm_provider import OllamaVlmProvider +from askui.model_providers.openai_image_qa_provider import OpenAIImageQAProvider +from askui.model_providers.openai_vlm_provider import OpenAIVlmProvider from askui.model_providers.vlm_provider import VlmProvider from askui.utils.model_pricing import ModelPricing @@ -35,5 +43,9 @@ "GoogleImageQAProvider", "ImageQAProvider", "ModelPricing", + "OllamaImageQAProvider", + 
"OllamaVlmProvider", + "OpenAIImageQAProvider", + "OpenAIVlmProvider", "VlmProvider", ] diff --git a/src/askui/model_providers/ollama_image_qa_provider.py b/src/askui/model_providers/ollama_image_qa_provider.py new file mode 100644 index 00000000..61f10df2 --- /dev/null +++ b/src/askui/model_providers/ollama_image_qa_provider.py @@ -0,0 +1,49 @@ +"""OllamaImageQAProvider — image Q&A via a local Ollama instance.""" + +from openai import OpenAI + +from askui.model_providers.openai_image_qa_provider import OpenAIImageQAProvider + +_DEFAULT_BASE_URL = "http://localhost:11434/v1" +_DEFAULT_MODEL_ID = "qwen3.5" + + +class OllamaImageQAProvider(OpenAIImageQAProvider): + """Image Q&A provider that routes requests to a local Ollama instance. + + Thin convenience wrapper around `OpenAIImageQAProvider` with Ollama + defaults (``base_url``, ``api_key``, ``model_id``). + + Args: + model_id (str, optional): Ollama model to use. Defaults to + ``"qwen3.5"``. + base_url (str, optional): Base URL for the Ollama OpenAI-compatible + API. Defaults to ``"http://localhost:11434/v1"``. + client (`OpenAI` | None, optional): Pre-configured OpenAI client. + If provided, ``base_url`` is ignored. + + Example: + ```python + from askui import AgentSettings, ComputerAgent + from askui.model_providers import OllamaImageQAProvider + + agent = ComputerAgent(settings=AgentSettings( + image_qa_provider=OllamaImageQAProvider( + model_id="llava", + ) + )) + ``` + """ + + def __init__( + self, + model_id: str = _DEFAULT_MODEL_ID, + base_url: str = _DEFAULT_BASE_URL, + client: OpenAI | None = None, + ) -> None: + super().__init__( + model_id=model_id, + api_key="ollama", # Ollama requires no auth; OpenAI SDK needs a value + base_url=base_url, + client=client, + ) diff --git a/src/askui/model_providers/ollama_vlm_provider.py b/src/askui/model_providers/ollama_vlm_provider.py new file mode 100644 index 00000000..e06fa408 --- /dev/null +++ b/src/askui/model_providers/ollama_vlm_provider.py @@ -0,0 +1,49 @@ +"""OllamaVlmProvider — VLM access via a local Ollama instance.""" + +from openai import OpenAI + +from askui.model_providers.openai_vlm_provider import OpenAIVlmProvider + +_DEFAULT_BASE_URL = "http://localhost:11434/v1" +_DEFAULT_MODEL_ID = "qwen3.5" + + +class OllamaVlmProvider(OpenAIVlmProvider): + """VLM provider that routes requests to a local Ollama instance. + + Thin convenience wrapper around `OpenAIVlmProvider` with Ollama + defaults (``base_url``, ``api_key``, ``model_id``). + + Args: + model_id (str, optional): Ollama model to use. Defaults to + ``"qwen3.5"``. + base_url (str, optional): Base URL for the Ollama OpenAI-compatible + API. Defaults to ``"http://localhost:11434/v1"``. + client (`OpenAI` | None, optional): Pre-configured OpenAI client. + If provided, ``base_url`` is ignored. 
+ + Example: + ```python + from askui import AgentSettings, ComputerAgent + from askui.model_providers import OllamaVlmProvider + + agent = ComputerAgent(settings=AgentSettings( + vlm_provider=OllamaVlmProvider( + model_id="qwen3.5", + ) + )) + ``` + """ + + def __init__( + self, + model_id: str = _DEFAULT_MODEL_ID, + base_url: str = _DEFAULT_BASE_URL, + client: OpenAI | None = None, + ) -> None: + super().__init__( + model_id=model_id, + api_key="ollama", # Ollama requires no auth; OpenAI SDK needs a value + base_url=base_url, + client=client, + ) diff --git a/src/askui/model_providers/openai_image_qa_provider.py b/src/askui/model_providers/openai_image_qa_provider.py new file mode 100644 index 00000000..06e2e45f --- /dev/null +++ b/src/askui/model_providers/openai_image_qa_provider.py @@ -0,0 +1,78 @@ +"""OpenAIImageQAProvider — image Q&A via any OpenAI-compatible API.""" + +from functools import cached_property +from typing import Type + +from openai import OpenAI +from typing_extensions import override + +from askui.model_providers.image_qa_provider import ImageQAProvider +from askui.models.openai.get_model import OpenAIGetModel +from askui.models.shared.settings import GetSettings +from askui.models.types.response_schemas import ResponseSchema +from askui.utils.source_utils import Source + + +class OpenAIImageQAProvider(ImageQAProvider): + """Image Q&A provider for any OpenAI-compatible API. + + Works with OpenAI, Ollama, vLLM, LM Studio, Together AI, and any + other service that exposes an OpenAI-compatible ``/v1/chat/completions`` + endpoint. + + Args: + model_id (str): Model name to use. + api_key (str | None, optional): API key. Reads ``OPENAI_API_KEY`` + from the environment if not provided. + base_url (str | None, optional): Base URL for the API. Defaults + to the OpenAI API (``https://api.openai.com/v1``). + client (`OpenAI` | None, optional): Pre-configured OpenAI client. + If provided, ``api_key`` and ``base_url`` are ignored. 
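+
+    Note:
+        Structured responses (``response_schema``) are requested via
+        OpenAI's ``json_schema`` response format, so the backend must
+        support structured outputs for schema-based queries to work.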
+
+    Example:
+        ```python
+        from askui import AgentSettings, ComputerAgent
+        from askui.model_providers import OpenAIImageQAProvider
+
+        agent = ComputerAgent(settings=AgentSettings(
+            image_qa_provider=OpenAIImageQAProvider(
+                model_id="gpt-4o",
+                api_key="sk-...",
+            )
+        ))
+        ```
+    """
+
+    def __init__(
+        self,
+        model_id: str,
+        api_key: str | None = None,
+        base_url: str | None = None,
+        client: OpenAI | None = None,
+    ) -> None:
+        self._model_id = model_id
+        self._client = client or OpenAI(
+            api_key=api_key,
+            base_url=base_url,
+        )
+
+    @cached_property
+    def _get_model(self) -> OpenAIGetModel:
+        """Lazily initialise the `OpenAIGetModel` on first use."""
+        return OpenAIGetModel(model_id=self._model_id, client=self._client)
+
+    @override
+    def query(
+        self,
+        query: str,
+        source: Source,
+        response_schema: Type[ResponseSchema] | None,
+        get_settings: GetSettings,
+    ) -> ResponseSchema | str:
+        result: ResponseSchema | str = self._get_model.get(
+            query=query,
+            source=source,
+            response_schema=response_schema,
+            get_settings=get_settings,
+        )
+        return result
diff --git a/src/askui/model_providers/openai_vlm_provider.py b/src/askui/model_providers/openai_vlm_provider.py
new file mode 100644
index 00000000..47475cc7
--- /dev/null
+++ b/src/askui/model_providers/openai_vlm_provider.py
@@ -0,0 +1,121 @@
+"""OpenAIVlmProvider — VLM access via any OpenAI-compatible API."""
+
+import os
+from functools import cached_property
+from typing import Any
+
+from openai import OpenAI
+from typing_extensions import override
+
+from askui.model_providers.vlm_provider import VlmProvider
+from askui.models.openai.messages_api import OpenAIMessagesApi
+from askui.models.shared.agent_message_param import (
+    MessageParam,
+    ThinkingConfigParam,
+    ToolChoiceParam,
+)
+from askui.models.shared.prompts import SystemPrompt
+from askui.models.shared.tools import ToolCollection
+from askui.utils.model_pricing import ModelPricing
+
+_DEFAULT_MODEL_ID = "gpt-5.4"
+
+
+class OpenAIVlmProvider(VlmProvider):
+    """VLM provider for any OpenAI-compatible API.
+
+    Works with OpenAI, Ollama, vLLM, LM Studio, Together AI, and any
+    other service that exposes an OpenAI-compatible ``/v1/chat/completions``
+    endpoint.
+
+    Args:
+        model_id (str | None, optional): Model name to use. Falls back to
+            the ``VLM_PROVIDER_MODEL_ID`` environment variable, then to
+            ``"gpt-5.4"``.
+        api_key (str | None, optional): API key. Reads ``OPENAI_API_KEY``
+            from the environment if not provided.
+        base_url (str | None, optional): Base URL for the API. Defaults
+            to the OpenAI API (``https://api.openai.com/v1``).
+        client (`OpenAI` | None, optional): Pre-configured OpenAI client.
+            If provided, ``api_key`` and ``base_url`` are ignored.
+        input_cost_per_million_tokens (float | None, optional): Pricing
+            override for cost tracking; likewise for the output and cache
+            read/write cost parameters.
+ + Example: + ```python + from askui import AgentSettings, ComputerAgent + from askui.model_providers import OpenAIVlmProvider + + agent = ComputerAgent(settings=AgentSettings( + vlm_provider=OpenAIVlmProvider( + model_id="gpt-4o", + api_key="sk-...", + ) + )) + ``` + """ + + def __init__( + self, + model_id: str | None = None, + api_key: str | None = None, + base_url: str | None = None, + client: OpenAI | None = None, + input_cost_per_million_tokens: float | None = None, + output_cost_per_million_tokens: float | None = None, + cache_write_cost_per_million_tokens: float | None = None, + cache_read_cost_per_million_tokens: float | None = None, + ) -> None: + self._model_id_value = ( + model_id or os.environ.get("VLM_PROVIDER_MODEL_ID") or _DEFAULT_MODEL_ID + ) + if client is not None: + self._client = client + else: + self._client = OpenAI( + api_key=api_key, + base_url=base_url, + ) + + self._pricing = ModelPricing.for_model( + self._model_id_value, + input_cost_per_million_tokens=input_cost_per_million_tokens, + output_cost_per_million_tokens=output_cost_per_million_tokens, + cache_write_cost_per_million_tokens=cache_write_cost_per_million_tokens, + cache_read_cost_per_million_tokens=cache_read_cost_per_million_tokens, + ) + + @property + @override + def model_id(self) -> str: + return self._model_id_value + + @property + @override + def pricing(self) -> ModelPricing | None: + return self._pricing + + @cached_property + def _messages_api(self) -> OpenAIMessagesApi: + """Lazily initialise the `OpenAIMessagesApi` on first use.""" + return OpenAIMessagesApi(client=self._client) + + @override + def create_message( + self, + messages: list[MessageParam], + tools: ToolCollection | None = None, + max_tokens: int | None = None, + system: SystemPrompt | None = None, + thinking: ThinkingConfigParam | None = None, + tool_choice: ToolChoiceParam | None = None, + temperature: float | None = None, + provider_options: dict[str, Any] | None = None, + ) -> MessageParam: + return self._messages_api.create_message( + messages=messages, + model_id=self._model_id_value, + tools=tools, + max_tokens=max_tokens, + system=system, + thinking=thinking, + tool_choice=tool_choice, + temperature=temperature, + provider_options=provider_options, + ) diff --git a/src/askui/models/openai/__init__.py b/src/askui/models/openai/__init__.py new file mode 100644 index 00000000..c79f4b94 --- /dev/null +++ b/src/askui/models/openai/__init__.py @@ -0,0 +1 @@ +"""Model integration via OpenAI-compatible APIs.""" diff --git a/src/askui/models/openai/get_model.py b/src/askui/models/openai/get_model.py new file mode 100644 index 00000000..f19f2892 --- /dev/null +++ b/src/askui/models/openai/get_model.py @@ -0,0 +1,169 @@ +"""OpenAIGetModel — GetModel for any OpenAI-compatible API.""" + +import json +import logging +from typing import TYPE_CHECKING, Any, Type + +import openai +from openai import OpenAI +from typing_extensions import override + +if TYPE_CHECKING: + from openai.types.chat import ChatCompletion + +from askui.models.exceptions import QueryNoResponseError +from askui.models.models import GetModel, GetSettings +from askui.models.shared.prompts import GetSystemPrompt +from askui.models.types.response_schemas import ResponseSchema, to_response_schema +from askui.prompts.get_prompts import SYSTEM_PROMPT_GET +from askui.utils.excel_utils import OfficeDocumentSource +from askui.utils.pdf_utils import PdfSource +from askui.utils.source_utils import Source + +logger = logging.getLogger(__name__) + + +def 
_clean_schema_refs(schema: dict[str, Any] | list[Any]) -> None: + """Remove ``title`` fields next to ``$ref`` fields (unsupported by OpenAI).""" + if isinstance(schema, dict): + if "$ref" in schema and "title" in schema: + del schema["title"] + for value in schema.values(): + if isinstance(value, (dict, list)): + _clean_schema_refs(value) + elif isinstance(schema, list): + for item in schema: + if isinstance(item, (dict, list)): + _clean_schema_refs(item) + + +class OpenAIGetModel(GetModel): + """GetModel implementation for any OpenAI-compatible API. + + Args: + model_id (str): The model name to use. + client (`OpenAI`): A pre-configured OpenAI client. + + Example: + ```python + from openai import OpenAI + from askui.models.openai.get_model import OpenAIGetModel + + client = OpenAI(api_key="sk-...") + model = OpenAIGetModel(model_id="gpt-4o", client=client) + ``` + """ + + def __init__( + self, + model_id: str, + client: OpenAI, + ) -> None: + self._model_id = model_id + self._client = client + + def _predict( + self, + image_url: str, + instruction: str, + prompt: GetSystemPrompt, + response_schema: type[ResponseSchema] | None, + ) -> str | None | ResponseSchema: + _response_schema = ( + to_response_schema(response_schema) if response_schema else None + ) + + response_format: openai.NotGiven | dict[str, Any] = openai.NOT_GIVEN + if _response_schema is not None: + schema = _response_schema.model_json_schema() + _clean_schema_refs(schema) + + defs = schema.pop("$defs", None) + schema_response_wrapper: dict[str, Any] = { + "type": "object", + "properties": {"response": schema}, + "additionalProperties": False, + "required": ["response"], + } + if defs: + schema_response_wrapper["$defs"] = defs + response_format = { + "type": "json_schema", + "json_schema": { + "name": "user_json_schema", + "schema": schema_response_wrapper, + "strict": True, + }, + } + + chat_completion: ChatCompletion = self._client.chat.completions.create( # type: ignore[call-overload] + model=self._model_id, + response_format=response_format, + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url, + }, + }, + {"type": "text", "text": str(prompt) + instruction}, + ], + } + ], + stream=False, + timeout=300.0, + ) + + model_response = chat_completion.choices[0].message.content + + if _response_schema is not None and model_response is not None: + try: + response_json = json.loads(model_response) + except json.JSONDecodeError: + error_msg = ( + f"Expected JSON, but model {self._model_id} " + f"returned: {model_response}" + ) + logger.exception( + "Expected JSON, but model returned", + extra={"model": self._model_id, "response": model_response}, + ) + raise ValueError(error_msg) from None + + validated_response = _response_schema.model_validate( + response_json["response"] + ) + return validated_response.root + + return model_response + + @override + def get( + self, + query: str, + source: Source, + response_schema: Type[ResponseSchema] | None, + get_settings: GetSettings, + ) -> ResponseSchema | str: + if isinstance(source, (PdfSource, OfficeDocumentSource)): + err_msg = ( + "PDF or Office Document processing is not supported" + " for OpenAI-compatible models" + ) + raise NotImplementedError(err_msg) + + system_prompt = get_settings.system_prompt or SYSTEM_PROMPT_GET + + response = self._predict( + image_url=source.to_data_url(), + instruction=query, + prompt=system_prompt, + response_schema=response_schema, + ) + if response is None: + error_msg = f'No response from 
model "{self._model_id}" to query: "{query}"' + raise QueryNoResponseError(error_msg, query) + return response diff --git a/src/askui/models/openai/messages_api.py b/src/askui/models/openai/messages_api.py new file mode 100644 index 00000000..ccd7e7e8 --- /dev/null +++ b/src/askui/models/openai/messages_api.py @@ -0,0 +1,332 @@ +"""OpenAIMessagesApi — MessagesApi for any OpenAI-compatible API.""" + +import json +import logging +from typing import Any + +from openai import OpenAI +from openai.types.chat import ChatCompletion, ChatCompletionMessage +from openai.types.chat.chat_completion_message_tool_call import ( + ChatCompletionMessageToolCall, +) +from typing_extensions import override + +from askui.models.shared.agent_message_param import ( + Base64ImageSourceParam, + BetaRedactedThinkingBlock, + BetaThinkingBlock, + ContentBlockParam, + ImageBlockParam, + MessageParam, + StopReason, + TextBlockParam, + ThinkingConfigParam, + ToolChoiceParam, + ToolResultBlockParam, + ToolUseBlockParam, + UsageParam, +) +from askui.models.shared.messages_api import MessagesApi +from askui.models.shared.prompts import SystemPrompt +from askui.models.shared.tools import ToolCollection + +logger = logging.getLogger(__name__) + +_FINISH_REASON_MAP: dict[str, StopReason] = { + "stop": "end_turn", + "length": "max_tokens", + "tool_calls": "tool_use", + "content_filter": "refusal", +} + + +def _map_finish_reason(finish_reason: str | None) -> StopReason | None: + """Map an OpenAI ``finish_reason`` to the internal `StopReason`.""" + if finish_reason is None: + return None + return _FINISH_REASON_MAP.get(finish_reason, "end_turn") + + +def _image_block_to_openai(block: ImageBlockParam) -> dict[str, Any]: + """Convert an `ImageBlockParam` to an OpenAI ``image_url`` content part.""" + if isinstance(block.source, Base64ImageSourceParam): + url = f"data:{block.source.media_type};base64,{block.source.data}" + else: + url = block.source.url + return {"type": "image_url", "image_url": {"url": url}} + + +def _serialize_tool_result_content( + content: str | list[TextBlockParam | ImageBlockParam], +) -> tuple[str, list[dict[str, Any]]]: + """Serialize ``ToolResultBlockParam.content`` for OpenAI's ``tool`` role. + + Returns the text portion as a string and any images as OpenAI content + parts (to be appended as a separate ``user`` message since the OpenAI + ``tool`` role only accepts string content). + """ + if isinstance(content, str): + return content, [] + + text_parts: list[str] = [] + image_parts: list[dict[str, Any]] = [] + for block in content: + if isinstance(block, TextBlockParam): + text_parts.append(block.text) + else: + image_parts.append(_image_block_to_openai(block)) + + return "\n".join(text_parts), image_parts + + +def _content_block_to_openai(block: ContentBlockParam) -> dict[str, Any] | None: + """Convert a user-message content block to an OpenAI content part. + + Returns ``None`` for block types that should be skipped (e.g. thinking). + """ + if isinstance(block, TextBlockParam): + return {"type": "text", "text": block.text} + if isinstance(block, ImageBlockParam): + return _image_block_to_openai(block) + if isinstance(block, (BetaThinkingBlock, BetaRedactedThinkingBlock)): + return None + return None + + +def _to_openai_messages( + messages: list[MessageParam], + system: SystemPrompt | None = None, +) -> list[dict[str, Any]]: + """Convert internal ``MessageParam`` list to OpenAI chat messages. 
+ + Handles: + - System prompt prepended as a ``system`` role message + - User messages with text/image content parts + - Assistant messages with text content and ``tool_calls`` + - Tool result messages converted to ``tool`` role + - Images inside tool results appended as a follow-up ``user`` message + - Thinking blocks silently skipped + """ + result: list[dict[str, Any]] = [] + + if system is not None: + result.append({"role": "system", "content": str(system)}) + + for message in messages: + if isinstance(message.content, str): + result.append({"role": message.role, "content": message.content}) + continue + + if message.role == "assistant": + _convert_assistant_message(message.content, result) + else: + _convert_user_message(message.content, result) + + return result + + +def _convert_assistant_message( + blocks: list[ContentBlockParam], + result: list[dict[str, Any]], +) -> None: + """Convert an assistant message's content blocks to OpenAI format.""" + text_parts: list[str] = [] + tool_calls: list[dict[str, Any]] = [] + + for block in blocks: + if isinstance(block, TextBlockParam): + text_parts.append(block.text) + elif isinstance(block, ToolUseBlockParam): + tool_calls.append( + { + "id": block.id, + "type": "function", + "function": { + "name": block.name, + "arguments": json.dumps(block.input), + }, + } + ) + # Skip thinking blocks silently + + openai_msg: dict[str, Any] = {"role": "assistant"} + content_text = "\n".join(text_parts) if text_parts else None + openai_msg["content"] = content_text + if tool_calls: + openai_msg["tool_calls"] = tool_calls + result.append(openai_msg) + + +def _convert_user_message( + blocks: list[ContentBlockParam], + result: list[dict[str, Any]], +) -> None: + """Convert a user message's content blocks to OpenAI format. + + ``ToolResultBlockParam`` blocks become ``tool`` role messages. + Images inside tool results are collected and appended as a separate + ``user`` message so the model can still see them. + """ + tool_result_images: list[dict[str, Any]] = [] + content_parts: list[dict[str, Any]] = [] + + for block in blocks: + if isinstance(block, ToolResultBlockParam): + text_content, images = _serialize_tool_result_content(block.content) + tool_result_images.extend(images) + result.append( + { + "role": "tool", + "tool_call_id": block.tool_use_id, + "content": text_content, + } + ) + else: + part = _content_block_to_openai(block) + if part is not None: + content_parts.append(part) + + if content_parts: + result.append({"role": "user", "content": content_parts}) + + # Append images from tool results as a separate user message + if tool_result_images: + result.append({"role": "user", "content": tool_result_images}) + + +def _to_openai_tools(tools: ToolCollection) -> list[dict[str, Any]]: + """Convert a `ToolCollection` to OpenAI function-calling tool format. + + Strips ``cache_control`` (Anthropic-specific) from tool parameters. 
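+
+    For example, an Anthropic-style tool param
+    ``{"name": "click", "description": "...", "input_schema": {...}}``
+    becomes
+    ``{"type": "function", "function": {"name": "click", "parameters": {...}}}``
+    with the description carried over when present.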
+ """ + result: list[dict[str, Any]] = [] + for tool_param in tools.to_params(): + schema = dict(tool_param.get("input_schema", {})) + schema.pop("cache_control", None) + func: dict[str, Any] = { + "name": tool_param["name"], + "parameters": schema, + } + if "description" in tool_param: + func["description"] = tool_param["description"] + result.append({"type": "function", "function": func}) + return result + + +def _parse_tool_calls( + message: ChatCompletionMessage, + content_blocks: list[ContentBlockParam], +) -> None: + """Extract tool calls from the OpenAI response and append as `ToolUseBlockParam`.""" + if not message.tool_calls: + return + for tool_call in message.tool_calls: + if not isinstance(tool_call, ChatCompletionMessageToolCall): + continue + content_blocks.append( + ToolUseBlockParam( + id=tool_call.id, + name=tool_call.function.name, + input=json.loads(tool_call.function.arguments), + ) + ) + + +def _from_openai_response(response: ChatCompletion) -> MessageParam: + """Convert an OpenAI ``ChatCompletion`` to an internal `MessageParam`.""" + choice = response.choices[0] + message = choice.message + stop_reason = _map_finish_reason(choice.finish_reason) + + content_blocks: list[ContentBlockParam] = [] + + if message.content: + content_blocks.append(TextBlockParam(text=message.content)) + + _parse_tool_calls(message, content_blocks) + + usage: UsageParam | None = None + if response.usage: + cached_tokens: int | None = None + if response.usage.prompt_tokens_details is not None: + cached_tokens = response.usage.prompt_tokens_details.cached_tokens + usage = UsageParam( + input_tokens=response.usage.prompt_tokens, + output_tokens=response.usage.completion_tokens, + cache_read_input_tokens=cached_tokens, + ) + + # Simple string content when there's only a single text block + if len(content_blocks) == 1 and isinstance(content_blocks[0], TextBlockParam): + return MessageParam( + role="assistant", + content=content_blocks[0].text, + stop_reason=stop_reason, + usage=usage, + ) + + return MessageParam( + role="assistant", + content=content_blocks, + stop_reason=stop_reason, + usage=usage, + ) + + +class OpenAIMessagesApi(MessagesApi): + """MessagesApi implementation for any OpenAI-compatible chat API.""" + + def __init__(self, client: OpenAI) -> None: + self._client = client + + @override + def create_message( + self, + messages: list[MessageParam], + model_id: str, + tools: ToolCollection | None = None, + max_tokens: int | None = None, + system: SystemPrompt | None = None, + thinking: ThinkingConfigParam | None = None, # noqa: ARG002 + tool_choice: ToolChoiceParam | None = None, # noqa: ARG002 + temperature: float | None = None, + provider_options: dict[str, Any] | None = None, # noqa: ARG002 + ) -> MessageParam: + """Create a message via an OpenAI-compatible chat completions endpoint. + + Args: + messages: The conversation history. + model_id: The model name (e.g. ``"gpt-4o"``, ``"qwen2.5vl"``). + tools: Tools available to the model for function-calling. + max_tokens: Maximum tokens to generate. + system: System prompt. + thinking: Ignored (not supported by the OpenAI chat API). + tool_choice: Ignored. + temperature: Sampling temperature. + provider_options: Ignored. + + Returns: + The model's response as a `MessageParam`. 
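+
+        Example (a minimal sketch; assumes an OpenAI-style endpoint):
+            ```python
+            from openai import OpenAI
+
+            api = OpenAIMessagesApi(client=OpenAI(api_key="sk-..."))
+            reply = api.create_message(
+                messages=[MessageParam(role="user", content="hi")],
+                model_id="gpt-4o",
+            )
+            ```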
+ """ + openai_messages = _to_openai_messages(messages, system) + + kwargs: dict[str, Any] = { + "model": model_id, + "messages": openai_messages, + "stream": False, + "timeout": 300.0, + } + + if max_tokens is not None: + kwargs["max_completion_tokens"] = max_tokens + + if temperature is not None: + kwargs["temperature"] = temperature + + if tools is not None: + openai_tools = _to_openai_tools(tools) + if openai_tools: + kwargs["tools"] = openai_tools + + response = self._client.chat.completions.create(**kwargs) + return _from_openai_response(response) diff --git a/src/askui/utils/model_pricing.py b/src/askui/utils/model_pricing.py index fceea324..4f7765a8 100644 --- a/src/askui/utils/model_pricing.py +++ b/src/askui/utils/model_pricing.py @@ -110,5 +110,23 @@ def for_model( cache_write_cost_per_million_tokens=6.25, cache_read_cost_per_million_tokens=0.50, ), + "gpt-5.4": ModelPricing( + input_cost_per_million_tokens=2.5, + output_cost_per_million_tokens=15.0, + cache_write_cost_per_million_tokens=2.5, + cache_read_cost_per_million_tokens=0.25, + ), + "gpt-5.4-mini": ModelPricing( + input_cost_per_million_tokens=0.75, + output_cost_per_million_tokens=4.50, + cache_write_cost_per_million_tokens=0.75, + cache_read_cost_per_million_tokens=0.0075, + ), + "gpt-5.4-nano": ModelPricing( + input_cost_per_million_tokens=0.20, + output_cost_per_million_tokens=1.25, + cache_write_cost_per_million_tokens=0.20, + cache_read_cost_per_million_tokens=0.02, + ), } ) diff --git a/tests/unit/model_providers/test_ollama_image_qa_provider.py b/tests/unit/model_providers/test_ollama_image_qa_provider.py new file mode 100644 index 00000000..a9359f39 --- /dev/null +++ b/tests/unit/model_providers/test_ollama_image_qa_provider.py @@ -0,0 +1,64 @@ +"""Unit tests for OllamaImageQAProvider.""" + +from unittest.mock import MagicMock + +from openai import OpenAI +from openai.types.chat import ChatCompletion, ChatCompletionMessage +from openai.types.chat.chat_completion import Choice +from openai.types.completion_usage import CompletionUsage + +from askui.model_providers.ollama_image_qa_provider import OllamaImageQAProvider +from askui.models.shared.settings import GetSettings +from askui.utils.image_utils import ImageSource + + +def _make_completion(content: str) -> ChatCompletion: + return ChatCompletion( + id="chatcmpl-test", + choices=[ + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage(role="assistant", content=content), + ) + ], + created=1234567890, + model="qwen2.5vl", + object="chat.completion", + usage=CompletionUsage(prompt_tokens=10, completion_tokens=20, total_tokens=30), + ) + + +class TestOllamaImageQAProvider: + def test_default_model_id(self) -> None: + provider = OllamaImageQAProvider() + assert provider._model_id == "qwen2.5vl" + + def test_custom_model_id(self) -> None: + provider = OllamaImageQAProvider(model_id="llava") + assert provider._model_id == "llava" + + def test_injected_client_used(self) -> None: + mock_client = MagicMock(spec=OpenAI) + provider = OllamaImageQAProvider(client=mock_client) + assert provider._client is mock_client + + def test_query_delegates_to_get_model(self) -> None: + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = _make_completion( + "The button" + ) + + source = MagicMock(spec=ImageSource) + source.to_data_url.return_value = "data:image/png;base64,abc" + + provider = OllamaImageQAProvider(model_id="qwen2.5vl", client=mock_client) + result = provider.query( + query="What is this?", + source=source, + 
response_schema=None, + get_settings=GetSettings(), + ) + + assert result == "The button" + mock_client.chat.completions.create.assert_called_once() diff --git a/tests/unit/model_providers/test_ollama_vlm_provider.py b/tests/unit/model_providers/test_ollama_vlm_provider.py new file mode 100644 index 00000000..70bc449f --- /dev/null +++ b/tests/unit/model_providers/test_ollama_vlm_provider.py @@ -0,0 +1,50 @@ +"""Unit tests for OllamaVlmProvider.""" + +from unittest.mock import MagicMock + +from openai import OpenAI + +from askui.model_providers.ollama_vlm_provider import OllamaVlmProvider +from askui.models.shared.agent_message_param import MessageParam + + +class TestOllamaVlmProvider: + def test_default_model_id(self) -> None: + provider = OllamaVlmProvider() + assert provider.model_id == "qwen2.5vl" + + def test_custom_model_id(self) -> None: + provider = OllamaVlmProvider(model_id="llava") + assert provider.model_id == "llava" + + def test_pricing_returns_none(self) -> None: + provider = OllamaVlmProvider() + assert provider.pricing is None + + def test_injected_client_used(self) -> None: + mock_client = MagicMock(spec=OpenAI) + provider = OllamaVlmProvider(client=mock_client) + assert provider._client is mock_client + + def test_create_message_delegates_to_messages_api(self) -> None: + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = MagicMock( + choices=[ + MagicMock( + finish_reason="stop", + message=MagicMock(content="done", tool_calls=None), + ) + ], + usage=MagicMock(prompt_tokens=5, completion_tokens=10), + ) + + provider = OllamaVlmProvider( + model_id="test-model", + client=mock_client, + ) + result = provider.create_message( + messages=[MessageParam(role="user", content="hi")], + ) + + mock_client.chat.completions.create.assert_called_once() + assert result.role == "assistant" diff --git a/tests/unit/model_providers/test_openai_image_qa_provider.py b/tests/unit/model_providers/test_openai_image_qa_provider.py new file mode 100644 index 00000000..e9c748ea --- /dev/null +++ b/tests/unit/model_providers/test_openai_image_qa_provider.py @@ -0,0 +1,56 @@ +"""Unit tests for OpenAIImageQAProvider.""" + +from unittest.mock import MagicMock + +from openai import OpenAI +from openai.types.chat import ChatCompletion, ChatCompletionMessage +from openai.types.chat.chat_completion import Choice +from openai.types.completion_usage import CompletionUsage + +from askui.model_providers.openai_image_qa_provider import OpenAIImageQAProvider +from askui.models.shared.settings import GetSettings +from askui.utils.image_utils import ImageSource + + +def _make_completion(content: str) -> ChatCompletion: + return ChatCompletion( + id="chatcmpl-test", + choices=[ + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage(role="assistant", content=content), + ) + ], + created=1234567890, + model="gpt-4o", + object="chat.completion", + usage=CompletionUsage(prompt_tokens=10, completion_tokens=20, total_tokens=30), + ) + + +class TestOpenAIImageQAProvider: + def test_injected_client_used(self) -> None: + mock_client = MagicMock(spec=OpenAI) + provider = OpenAIImageQAProvider(model_id="gpt-4o", client=mock_client) + assert provider._client is mock_client + + def test_query_delegates_to_get_model(self) -> None: + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = _make_completion( + "The button" + ) + + source = MagicMock(spec=ImageSource) + source.to_data_url.return_value = "data:image/png;base64,abc" + + provider = 
OpenAIImageQAProvider(model_id="gpt-4o", client=mock_client) + result = provider.query( + query="What is this?", + source=source, + response_schema=None, + get_settings=GetSettings(), + ) + + assert result == "The button" + mock_client.chat.completions.create.assert_called_once() diff --git a/tests/unit/model_providers/test_openai_vlm_provider.py b/tests/unit/model_providers/test_openai_vlm_provider.py new file mode 100644 index 00000000..d51ff74b --- /dev/null +++ b/tests/unit/model_providers/test_openai_vlm_provider.py @@ -0,0 +1,43 @@ +"""Unit tests for OpenAIVlmProvider.""" + +from unittest.mock import MagicMock + +from openai import OpenAI + +from askui.model_providers.openai_vlm_provider import OpenAIVlmProvider +from askui.models.shared.agent_message_param import MessageParam + + +class TestOpenAIVlmProvider: + def test_model_id(self) -> None: + provider = OpenAIVlmProvider(model_id="gpt-4o", api_key="sk-test") + assert provider.model_id == "gpt-4o" + + def test_pricing_returns_none(self) -> None: + provider = OpenAIVlmProvider(model_id="gpt-4o", api_key="sk-test") + assert provider.pricing is None + + def test_injected_client_used(self) -> None: + mock_client = MagicMock(spec=OpenAI) + provider = OpenAIVlmProvider(model_id="gpt-4o", client=mock_client) + assert provider._client is mock_client + + def test_create_message_delegates_to_messages_api(self) -> None: + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = MagicMock( + choices=[ + MagicMock( + finish_reason="stop", + message=MagicMock(content="done", tool_calls=None), + ) + ], + usage=MagicMock(prompt_tokens=5, completion_tokens=10), + ) + + provider = OpenAIVlmProvider(model_id="gpt-4o", client=mock_client) + result = provider.create_message( + messages=[MessageParam(role="user", content="hi")], + ) + + mock_client.chat.completions.create.assert_called_once() + assert result.role == "assistant" diff --git a/tests/unit/models/openai/__init__.py b/tests/unit/models/openai/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/models/openai/test_get_model.py b/tests/unit/models/openai/test_get_model.py new file mode 100644 index 00000000..fef21b9b --- /dev/null +++ b/tests/unit/models/openai/test_get_model.py @@ -0,0 +1,96 @@ +"""Unit tests for OpenAIGetModel.""" + +from unittest.mock import MagicMock + +import pytest +from openai.types.chat import ChatCompletion, ChatCompletionMessage +from openai.types.chat.chat_completion import Choice +from openai.types.completion_usage import CompletionUsage + +from askui.models.exceptions import QueryNoResponseError +from askui.models.openai.get_model import OpenAIGetModel +from askui.models.shared.settings import GetSettings +from askui.utils.excel_utils import OfficeDocumentSource +from askui.utils.image_utils import ImageSource +from askui.utils.pdf_utils import PdfSource + + +def _make_completion(content: str | None) -> ChatCompletion: + return ChatCompletion( + id="chatcmpl-test", + choices=[ + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage(role="assistant", content=content), + ) + ], + created=1234567890, + model="qwen2.5vl", + object="chat.completion", + usage=CompletionUsage(prompt_tokens=10, completion_tokens=20, total_tokens=30), + ) + + +class TestOpenAIGetModel: + def test_basic_query_returns_string(self) -> None: + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = _make_completion( + "The button says Submit" + ) + + source = MagicMock(spec=ImageSource) + 
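+        # Only to_data_url() is stubbed; it is the sole Source method the model calls.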
source.to_data_url.return_value = "data:image/png;base64,abc" + + model = OpenAIGetModel(model_id="qwen2.5vl", client=mock_client) + result = model.get( + query="What does the button say?", + source=source, + response_schema=None, + get_settings=GetSettings(), + ) + + assert result == "The button says Submit" + mock_client.chat.completions.create.assert_called_once() + + def test_no_response_raises_error(self) -> None: + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = _make_completion(None) + + source = MagicMock(spec=ImageSource) + source.to_data_url.return_value = "data:image/png;base64,abc" + + model = OpenAIGetModel(model_id="qwen2.5vl", client=mock_client) + with pytest.raises(QueryNoResponseError): + model.get( + query="Describe", + source=source, + response_schema=None, + get_settings=GetSettings(), + ) + + def test_pdf_source_not_supported(self) -> None: + mock_client = MagicMock() + source = MagicMock(spec=PdfSource) + + model = OpenAIGetModel(model_id="qwen2.5vl", client=mock_client) + with pytest.raises(NotImplementedError, match="PDF or Office Document"): + model.get( + query="Describe", + source=source, + response_schema=None, + get_settings=GetSettings(), + ) + + def test_office_document_source_not_supported(self) -> None: + mock_client = MagicMock() + source = MagicMock(spec=OfficeDocumentSource) + + model = OpenAIGetModel(model_id="qwen2.5vl", client=mock_client) + with pytest.raises(NotImplementedError, match="PDF or Office Document"): + model.get( + query="Describe", + source=source, + response_schema=None, + get_settings=GetSettings(), + ) diff --git a/tests/unit/models/openai/test_messages_api.py b/tests/unit/models/openai/test_messages_api.py new file mode 100644 index 00000000..22fbcbce --- /dev/null +++ b/tests/unit/models/openai/test_messages_api.py @@ -0,0 +1,443 @@ +"""Unit tests for OpenAI messages API conversion functions.""" + +from unittest.mock import MagicMock + +from openai.types.chat import ChatCompletion, ChatCompletionMessage +from openai.types.chat.chat_completion import Choice +from openai.types.chat.chat_completion_message_tool_call import ( + ChatCompletionMessageToolCall, + Function, +) +from openai.types.completion_usage import CompletionUsage + +from askui.models.openai.messages_api import ( + OpenAIMessagesApi, + _from_openai_response, + _image_block_to_openai, + _map_finish_reason, + _serialize_tool_result_content, + _to_openai_messages, + _to_openai_tools, +) +from askui.models.shared.agent_message_param import ( + Base64ImageSourceParam, + BetaRedactedThinkingBlock, + BetaThinkingBlock, + ImageBlockParam, + MessageParam, + TextBlockParam, + ToolResultBlockParam, + ToolUseBlockParam, + UrlImageSourceParam, +) +from askui.models.shared.prompts import SystemPrompt + + +def _make_completion( + content: str | None = None, + tool_calls: list[ChatCompletionMessageToolCall] | None = None, + finish_reason: str = "stop", + prompt_tokens: int = 10, + completion_tokens: int = 20, +) -> ChatCompletion: + """Create a mock ChatCompletion response.""" + return ChatCompletion( + id="chatcmpl-test", + choices=[ + Choice( + finish_reason=finish_reason, + index=0, + message=ChatCompletionMessage( + role="assistant", + content=content, + tool_calls=tool_calls, + ), + ) + ], + created=1234567890, + model="qwen2.5vl", + object="chat.completion", + usage=CompletionUsage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ), + ) + + +class TestMapFinishReason: + def 
test_stop_maps_to_end_turn(self) -> None: + assert _map_finish_reason("stop") == "end_turn" + + def test_length_maps_to_max_tokens(self) -> None: + assert _map_finish_reason("length") == "max_tokens" + + def test_tool_calls_maps_to_tool_use(self) -> None: + assert _map_finish_reason("tool_calls") == "tool_use" + + def test_content_filter_maps_to_refusal(self) -> None: + assert _map_finish_reason("content_filter") == "refusal" + + def test_none_returns_none(self) -> None: + assert _map_finish_reason(None) is None + + def test_unknown_falls_back_to_end_turn(self) -> None: + assert _map_finish_reason("unknown_reason") == "end_turn" + + +class TestImageBlockToOpenai: + def test_base64_image(self) -> None: + block = ImageBlockParam( + source=Base64ImageSourceParam(data="aWltYWdl", media_type="image/png") + ) + result = _image_block_to_openai(block) + assert result == { + "type": "image_url", + "image_url": {"url": "data:image/png;base64,aWltYWdl"}, + } + + def test_url_image(self) -> None: + block = ImageBlockParam( + source=UrlImageSourceParam(url="https://example.com/img.png") + ) + result = _image_block_to_openai(block) + assert result == { + "type": "image_url", + "image_url": {"url": "https://example.com/img.png"}, + } + + +class TestSerializeToolResultContent: + def test_string_content(self) -> None: + text, images = _serialize_tool_result_content("hello") + assert text == "hello" + assert images == [] + + def test_text_blocks(self) -> None: + content: list[TextBlockParam | ImageBlockParam] = [ + TextBlockParam(text="line1"), + TextBlockParam(text="line2"), + ] + text, images = _serialize_tool_result_content(content) + assert text == "line1\nline2" + assert images == [] + + def test_image_blocks_extracted(self) -> None: + content: list[TextBlockParam | ImageBlockParam] = [ + TextBlockParam(text="screenshot"), + ImageBlockParam( + source=Base64ImageSourceParam(data="abc", media_type="image/png") + ), + ] + text, images = _serialize_tool_result_content(content) + assert text == "screenshot" + assert len(images) == 1 + assert images[0]["type"] == "image_url" + + +class TestToOpenaiMessages: + def test_simple_text_message(self) -> None: + messages = [MessageParam(role="user", content="hello")] + result = _to_openai_messages(messages) + assert result == [{"role": "user", "content": "hello"}] + + def test_system_prompt_prepended(self) -> None: + system = SystemPrompt(prompt="Be helpful.") + messages = [MessageParam(role="user", content="hi")] + result = _to_openai_messages(messages, system) + assert result[0] == {"role": "system", "content": "Be helpful."} + assert result[1] == {"role": "user", "content": "hi"} + + def test_user_message_with_image(self) -> None: + messages = [ + MessageParam( + role="user", + content=[ + TextBlockParam(text="What is this?"), + ImageBlockParam( + source=Base64ImageSourceParam( + data="abc123", media_type="image/png" + ) + ), + ], + ) + ] + result = _to_openai_messages(messages) + assert len(result) == 1 + assert result[0]["role"] == "user" + parts = result[0]["content"] + assert len(parts) == 2 + assert parts[0] == {"type": "text", "text": "What is this?"} + assert parts[1]["type"] == "image_url" + + def test_assistant_message_with_tool_calls(self) -> None: + messages = [ + MessageParam( + role="assistant", + content=[ + TextBlockParam(text="I'll take a screenshot."), + ToolUseBlockParam( + id="call_1", + name="screenshot", + input={}, + ), + ], + ) + ] + result = _to_openai_messages(messages) + assert len(result) == 1 + msg = result[0] + assert msg["role"] == 
"assistant" + assert msg["content"] == "I'll take a screenshot." + assert len(msg["tool_calls"]) == 1 + tc = msg["tool_calls"][0] + assert tc["id"] == "call_1" + assert tc["function"]["name"] == "screenshot" + assert tc["function"]["arguments"] == "{}" + + def test_tool_result_message(self) -> None: + messages = [ + MessageParam( + role="user", + content=[ + ToolResultBlockParam( + tool_use_id="call_1", + content="Done", + ), + ], + ) + ] + result = _to_openai_messages(messages) + assert len(result) == 1 + assert result[0] == { + "role": "tool", + "tool_call_id": "call_1", + "content": "Done", + } + + def test_tool_result_with_images_appended_as_user_message(self) -> None: + messages = [ + MessageParam( + role="user", + content=[ + ToolResultBlockParam( + tool_use_id="call_1", + content=[ + TextBlockParam(text="Screenshot taken"), + ImageBlockParam( + source=Base64ImageSourceParam( + data="img", media_type="image/png" + ) + ), + ], + ), + ], + ) + ] + result = _to_openai_messages(messages) + # Should produce: tool message + user message with image + assert len(result) == 2 + assert result[0]["role"] == "tool" + assert result[0]["content"] == "Screenshot taken" + assert result[1]["role"] == "user" + assert result[1]["content"][0]["type"] == "image_url" + + def test_thinking_blocks_skipped(self) -> None: + messages = [ + MessageParam( + role="assistant", + content=[ + BetaThinkingBlock( + signature="sig", thinking="hmm...", type="thinking" + ), + TextBlockParam(text="The answer is 42."), + ], + ) + ] + result = _to_openai_messages(messages) + assert len(result) == 1 + assert result[0]["content"] == "The answer is 42." + assert "tool_calls" not in result[0] + + def test_redacted_thinking_blocks_skipped(self) -> None: + messages = [ + MessageParam( + role="assistant", + content=[ + BetaRedactedThinkingBlock( + data="redacted", type="redacted_thinking" + ), + TextBlockParam(text="Result."), + ], + ) + ] + result = _to_openai_messages(messages) + assert result[0]["content"] == "Result." + + +class TestToOpenaiTools: + def test_converts_tool_collection(self) -> None: + tool_collection = MagicMock() + tool_collection.to_params.return_value = [ + { + "name": "click", + "description": "Click an element", + "input_schema": { + "type": "object", + "properties": {"x": {"type": "integer"}}, + }, + } + ] + result = _to_openai_tools(tool_collection) + assert len(result) == 1 + assert result[0] == { + "type": "function", + "function": { + "name": "click", + "description": "Click an element", + "parameters": { + "type": "object", + "properties": {"x": {"type": "integer"}}, + }, + }, + } + + def test_strips_cache_control(self) -> None: + tool_collection = MagicMock() + tool_collection.to_params.return_value = [ + { + "name": "screenshot", + "description": "Take screenshot", + "input_schema": { + "type": "object", + "cache_control": {"type": "ephemeral"}, + }, + } + ] + result = _to_openai_tools(tool_collection) + assert "cache_control" not in result[0]["function"]["parameters"] + + +class TestFromOpenaiResponse: + def test_text_only_response(self) -> None: + completion = _make_completion(content="Hello!") + result = _from_openai_response(completion) + assert result.role == "assistant" + assert result.content == "Hello!" 
+ assert result.stop_reason == "end_turn" + assert result.usage is not None + assert result.usage.input_tokens == 10 + assert result.usage.output_tokens == 20 + + def test_tool_call_response(self) -> None: + tool_calls = [ + ChatCompletionMessageToolCall( + id="call_1", + type="function", + function=Function( + name="click", + arguments='{"x": 100, "y": 200}', + ), + ) + ] + completion = _make_completion(tool_calls=tool_calls, finish_reason="tool_calls") + result = _from_openai_response(completion) + assert result.role == "assistant" + assert result.stop_reason == "tool_use" + assert isinstance(result.content, list) + assert len(result.content) == 1 + block = result.content[0] + assert isinstance(block, ToolUseBlockParam) + assert block.id == "call_1" + assert block.name == "click" + assert block.input == {"x": 100, "y": 200} + + def test_text_and_tool_calls(self) -> None: + tool_calls = [ + ChatCompletionMessageToolCall( + id="call_1", + type="function", + function=Function(name="screenshot", arguments="{}"), + ) + ] + completion = _make_completion( + content="Let me take a screenshot.", + tool_calls=tool_calls, + finish_reason="tool_calls", + ) + result = _from_openai_response(completion) + assert isinstance(result.content, list) + assert len(result.content) == 2 + assert isinstance(result.content[0], TextBlockParam) + assert isinstance(result.content[1], ToolUseBlockParam) + + def test_usage_captured(self) -> None: + completion = _make_completion( + content="ok", prompt_tokens=50, completion_tokens=100 + ) + result = _from_openai_response(completion) + assert result.usage is not None + assert result.usage.input_tokens == 50 + assert result.usage.output_tokens == 100 + + +class TestOpenAIMessagesApi: + def test_create_message_delegates_to_client(self) -> None: + mock_client = MagicMock() + completion = _make_completion(content="response") + mock_client.chat.completions.create.return_value = completion + + api = OpenAIMessagesApi(client=mock_client) + result = api.create_message( + messages=[MessageParam(role="user", content="hello")], + model_id="qwen2.5vl", + ) + + mock_client.chat.completions.create.assert_called_once() + call_kwargs = mock_client.chat.completions.create.call_args[1] + assert call_kwargs["model"] == "qwen2.5vl" + assert call_kwargs["stream"] is False + assert result.content == "response" + + def test_tools_passed_when_provided(self) -> None: + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = _make_completion( + content="ok" + ) + + mock_tools = MagicMock() + mock_tools.to_params.return_value = [ + { + "name": "click", + "description": "Click", + "input_schema": {"type": "object"}, + } + ] + + api = OpenAIMessagesApi(client=mock_client) + api.create_message( + messages=[MessageParam(role="user", content="hi")], + model_id="test", + tools=mock_tools, + ) + + call_kwargs = mock_client.chat.completions.create.call_args[1] + assert "tools" in call_kwargs + assert call_kwargs["tools"][0]["type"] == "function" + + def test_optional_params_omitted_when_none(self) -> None: + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = _make_completion( + content="ok" + ) + + api = OpenAIMessagesApi(client=mock_client) + api.create_message( + messages=[MessageParam(role="user", content="hi")], + model_id="test", + ) + + call_kwargs = mock_client.chat.completions.create.call_args[1] + assert "max_tokens" not in call_kwargs + assert "temperature" not in call_kwargs + assert "tools" not in call_kwargs From 
5729f0513f83a3c5aac438a79ccceff03626ecff Mon Sep 17 00:00:00 2001 From: philipph-askui Date: Mon, 27 Apr 2026 10:48:47 +0200 Subject: [PATCH 2/3] fix: default model in ollama provider tests --- tests/unit/model_providers/test_ollama_image_qa_provider.py | 4 ++-- tests/unit/model_providers/test_ollama_vlm_provider.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/model_providers/test_ollama_image_qa_provider.py b/tests/unit/model_providers/test_ollama_image_qa_provider.py index a9359f39..84e5168e 100644 --- a/tests/unit/model_providers/test_ollama_image_qa_provider.py +++ b/tests/unit/model_providers/test_ollama_image_qa_provider.py @@ -23,7 +23,7 @@ def _make_completion(content: str) -> ChatCompletion: ) ], created=1234567890, - model="qwen2.5vl", + model="qwen3.5", object="chat.completion", usage=CompletionUsage(prompt_tokens=10, completion_tokens=20, total_tokens=30), ) @@ -32,7 +32,7 @@ def _make_completion(content: str) -> ChatCompletion: class TestOllamaImageQAProvider: def test_default_model_id(self) -> None: provider = OllamaImageQAProvider() - assert provider._model_id == "qwen2.5vl" + assert provider._model_id == "qwen3.5" def test_custom_model_id(self) -> None: provider = OllamaImageQAProvider(model_id="llava") diff --git a/tests/unit/model_providers/test_ollama_vlm_provider.py b/tests/unit/model_providers/test_ollama_vlm_provider.py index 70bc449f..143e7c35 100644 --- a/tests/unit/model_providers/test_ollama_vlm_provider.py +++ b/tests/unit/model_providers/test_ollama_vlm_provider.py @@ -11,7 +11,7 @@ class TestOllamaVlmProvider: def test_default_model_id(self) -> None: provider = OllamaVlmProvider() - assert provider.model_id == "qwen2.5vl" + assert provider.model_id == "qwen3.5" def test_custom_model_id(self) -> None: provider = OllamaVlmProvider(model_id="llava") From 91bc0d60544d005fa1ffe72c1aabd24b3f2a3aa9 Mon Sep 17 00:00:00 2001 From: philipph-askui Date: Mon, 27 Apr 2026 18:06:57 +0200 Subject: [PATCH 3/3] feat: add openAI-compatible vlm provider with custom endpoint (i/o baseurl) --- pyproject.toml | 1 + src/askui/model_providers/__init__.py | 5 ++ .../openai_compatible_vlm_provider.py | 59 ++++++++++++++ .../test_openai_compatible_vlm_provider.py | 77 +++++++++++++++++++ 4 files changed, 142 insertions(+) create mode 100644 src/askui/model_providers/openai_compatible_vlm_provider.py create mode 100644 tests/unit/model_providers/test_openai_compatible_vlm_provider.py diff --git a/pyproject.toml b/pyproject.toml index d8bfc7e4..fe16e57c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -205,6 +205,7 @@ dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" "tests/unit/locators/serializers/test_locator_string_representation.py" = ["E501"] "tests/unit/locators/test_locators.py" = ["E501"] "tests/unit/utils/test_image_utils.py" = ["E501"] +"playground.py" = ["F401", "E501"] [tool.ruff.lint.flake8-quotes] docstring-quotes = "double" diff --git a/src/askui/model_providers/__init__.py b/src/askui/model_providers/__init__.py index b804e473..ae1f0d0d 100644 --- a/src/askui/model_providers/__init__.py +++ b/src/askui/model_providers/__init__.py @@ -16,6 +16,7 @@ - `OpenAIImageQAProvider` — image Q&A via any OpenAI-compatible API - `OllamaVlmProvider` — VLM via local Ollama instance (OpenAI-compatible) - `OllamaImageQAProvider` — image Q&A via local Ollama instance (OpenAI-compatible) +- `OpenAICompatibleVlmProvider` — VLM via OpenAI-compatible API with fixed URL """ from askui.model_providers.anthropic_image_qa_provider import 
AnthropicImageQAProvider @@ -28,6 +29,9 @@ from askui.model_providers.image_qa_provider import ImageQAProvider from askui.model_providers.ollama_image_qa_provider import OllamaImageQAProvider from askui.model_providers.ollama_vlm_provider import OllamaVlmProvider +from askui.model_providers.openai_compatible_vlm_provider import ( + OpenAICompatibleVlmProvider, +) from askui.model_providers.openai_image_qa_provider import OpenAIImageQAProvider from askui.model_providers.openai_vlm_provider import OpenAIVlmProvider from askui.model_providers.vlm_provider import VlmProvider @@ -47,5 +51,6 @@ "OllamaVlmProvider", "OpenAIImageQAProvider", "OpenAIVlmProvider", + "OpenAICompatibleVlmProvider", "VlmProvider", ] diff --git a/src/askui/model_providers/openai_compatible_vlm_provider.py b/src/askui/model_providers/openai_compatible_vlm_provider.py new file mode 100644 index 00000000..aae55c11 --- /dev/null +++ b/src/askui/model_providers/openai_compatible_vlm_provider.py @@ -0,0 +1,59 @@ +"""OpenAICompatibleVlmProvider — VLM access via a fixed endpoint URL.""" + +import httpx +from openai import OpenAI + +from askui.model_providers.openai_vlm_provider import OpenAIVlmProvider + + +class OpenAICompatibleVlmProvider(OpenAIVlmProvider): + """VLM provider for OpenAI-compatible APIs that require an exact endpoint URL. + + The OpenAI SDK always appends ``/chat/completions`` to ``base_url``, + which breaks endpoints that already include the full path (e.g. RunPod, + custom proxies, serverless deployments). This provider works around + the issue by installing an httpx event hook that rewrites every + outgoing request URL to the exact ``endpoint_url``. + + Args: + endpoint_url (str): Full endpoint URL including the path + (e.g. ``"https://my-host/v1/chat/completions"``). + model_id (str): Model name expected by the deployment. + api_key (str | None, optional): API key for the endpoint. 
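+
+    Note:
+        The event hook rewrites every request from this client to
+        ``endpoint_url``, so the underlying client should not be reused
+        for other routes. If ``model_id`` is omitted, the
+        `OpenAIVlmProvider` defaults apply (the ``VLM_PROVIDER_MODEL_ID``
+        environment variable, then the built-in default).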
+ + Example: + ```python + from askui import AgentSettings, ComputerAgent + from askui.model_providers import OpenAICompatibleVlmProvider + + agent = ComputerAgent(settings=AgentSettings( + vlm_provider=OpenAICompatibleVlmProvider( + endpoint_url="https://my-host/v1/chat/completions", + model_id="my-model", + api_key="...", + ) + )) + ``` + """ + + def __init__( + self, + endpoint_url: str, + model_id: str | None = None, + api_key: str | None = None, + ) -> None: + def _rewrite_url(request: httpx.Request) -> None: + request.url = httpx.URL(endpoint_url) + + http_client = httpx.Client(event_hooks={"request": [_rewrite_url]}) + + client = OpenAI( + api_key=api_key, + base_url=endpoint_url, + http_client=http_client, + ) + + super().__init__( + model_id=model_id, + client=client, + ) diff --git a/tests/unit/model_providers/test_openai_compatible_vlm_provider.py b/tests/unit/model_providers/test_openai_compatible_vlm_provider.py new file mode 100644 index 00000000..21670c5a --- /dev/null +++ b/tests/unit/model_providers/test_openai_compatible_vlm_provider.py @@ -0,0 +1,77 @@ +"""Unit tests for OpenAICompatibleVlmProvider.""" + +from unittest.mock import MagicMock + +import httpx + +from askui.model_providers.openai_compatible_vlm_provider import ( + OpenAICompatibleVlmProvider, +) +from askui.models.shared.agent_message_param import MessageParam + + +class TestOpenAICompatibleVlmProvider: + def test_model_id(self) -> None: + provider = OpenAICompatibleVlmProvider( + endpoint_url="https://my-host/v1/chat/completions", + model_id="my-model", + api_key="test-key", + ) + assert provider.model_id == "my-model" + + def test_pricing_returns_none(self) -> None: + provider = OpenAICompatibleVlmProvider( + endpoint_url="https://my-host/v1/chat/completions", + model_id="my-model", + api_key="test-key", + ) + assert provider.pricing is None + + def test_injected_client_is_openai_instance(self) -> None: + provider = OpenAICompatibleVlmProvider( + endpoint_url="https://my-host/v1/chat/completions", + model_id="my-model", + api_key="test-key", + ) + assert provider._client is not None + + def test_httpx_event_hook_rewrites_url(self) -> None: + endpoint_url = "https://my-host/v1/chat/completions" + provider = OpenAICompatibleVlmProvider( + endpoint_url=endpoint_url, + model_id="my-model", + api_key="test-key", + ) + + http_client: httpx.Client = provider._client._client + request = httpx.Request("POST", "https://api.openai.com/v1/chat/completions") + for hook in http_client.event_hooks["request"]: + hook(request) + + assert str(request.url) == endpoint_url + + def test_create_message_delegates_to_messages_api(self) -> None: + provider = OpenAICompatibleVlmProvider( + endpoint_url="https://my-host/v1/chat/completions", + model_id="test-model", + api_key="test-key", + ) + + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = MagicMock( + choices=[ + MagicMock( + finish_reason="stop", + message=MagicMock(content="done", tool_calls=None), + ) + ], + usage=MagicMock(prompt_tokens=5, completion_tokens=10), + ) + provider._client = mock_client + + result = provider.create_message( + messages=[MessageParam(role="user", content="hi")], + ) + + mock_client.chat.completions.create.assert_called_once() + assert result.role == "assistant"