Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/typeagent/aitools/model_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,14 @@ def _make_azure_provider(
azure_ad_token_provider=token_provider.get_token,
)
else:
apim_key = os.getenv("AZURE_APIM_SUBSCRIPTION_KEY")
client = AsyncAzureOpenAI(
azure_endpoint=azure_endpoint,
api_version=api_version,
api_key=raw_key,
default_headers=(
{"Ocp-Apim-Subscription-Key": apim_key} if apim_key else None
),
)
Comment thread
bmerkle marked this conversation as resolved.
return AzureProvider(openai_client=client)

Expand Down
29 changes: 13 additions & 16 deletions src/typeagent/aitools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,10 @@ def parse_azure_endpoint(
f"{endpoint_envvar}={azure_endpoint} doesn't contain valid api-version field"
)

# Strip query string — AsyncAzureOpenAI expects a clean base URL and
# receives api_version as a separate parameter.
# Strip query string and /openai... path — AsyncAzureOpenAI expects a
# clean base URL and builds the deployment path internally.
clean_endpoint = azure_endpoint.split("?", 1)[0]
clean_endpoint = re.sub(r"/openai(/deployments/.*)?$", "", clean_endpoint)
Comment thread
gvanrossum marked this conversation as resolved.

return clean_endpoint, m.group(1)

Expand Down Expand Up @@ -254,10 +255,15 @@ def create_async_openai_client(
azure_api_key = get_azure_api_key(azure_api_key)
azure_endpoint, api_version = parse_azure_endpoint(endpoint_envvar)

apim_key = os.getenv("AZURE_APIM_SUBSCRIPTION_KEY")

return AsyncAzureOpenAI(
api_version=api_version,
azure_endpoint=azure_endpoint,
api_key=azure_api_key,
default_headers=(
{"Ocp-Apim-Subscription-Key": apim_key} if apim_key else None
),
)

else:
Expand All @@ -271,30 +277,21 @@ def make_agent[T](cls: type[T]):
"""Create Pydantic AI agent using hardcoded preferences."""
from pydantic_ai import Agent, NativeOutput, ToolOutput
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.providers.azure import AzureProvider

# Prefer straight OpenAI over Azure OpenAI.
if os.getenv("OPENAI_API_KEY"):
Wrapper = NativeOutput
print(f"## Using OpenAI with {Wrapper.__name__} ##")
model = OpenAIChatModel("gpt-4o") # Retrieves OPENAI_API_KEY again.

elif azure_api_key := os.getenv("AZURE_OPENAI_API_KEY"):
azure_api_key = get_azure_api_key(azure_api_key)
azure_endpoint, api_version = parse_azure_endpoint("AZURE_OPENAI_ENDPOINT")
elif os.getenv("AZURE_OPENAI_API_KEY"):
from typeagent.aitools.model_adapters import _make_azure_provider

print(f"## {azure_endpoint} ##")
azure_provider = _make_azure_provider()
Wrapper = ToolOutput

print(f"## Using Azure {api_version} with {Wrapper.__name__} ##")
model = OpenAIChatModel(
"gpt-4o",
provider=AzureProvider(
azure_endpoint=azure_endpoint,
api_version=api_version,
api_key=azure_api_key,
),
)
print(f"## Using Azure with {Wrapper.__name__} ##")
model = OpenAIChatModel("gpt-4o", provider=azure_provider)

else:
raise RuntimeError(
Expand Down
40 changes: 14 additions & 26 deletions tests/test_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,36 +3,24 @@

import pytest

from typeagent.aitools.utils import create_async_openai_client
import typechat

from typeagent.aitools.model_adapters import create_chat_model


@pytest.mark.asyncio
async def test_why_is_sky_blue(really_needs_auth: None):
"""Test that chat agent responds correctly to 'why is the sky blue?'"""

# Create an async OpenAI client
try:
client = create_async_openai_client()
except RuntimeError as e:
if "Neither OPENAI_API_KEY nor AZURE_OPENAI_API_KEY was provided." in str(e):
pytest.skip("API keys not configured")
raise

# Send the user request
response = await client.chat.completions.create(
model="gpt-4o",
messages=[
{
"role": "user",
"content": "why is the sky blue?",
}
],
temperature=0,
)

# Get the response message
msg = response.choices[0].message.content
assert msg is not None, "Chat agent didn't respond"
"""Test that chat agent responds correctly to 'why is the sky blue?'

Uses create_chat_model (the pydantic-ai code path) so this test exercises
the same Azure provider wiring as the rest of the codebase.
"""
model = create_chat_model()

result = await model.complete("why is the sky blue?")
assert isinstance(result, typechat.Success), f"Chat completion failed: {result}"
msg = result.value
assert msg, "Chat agent didn't respond"

print(f"Chat agent response: {msg}")

Expand Down
26 changes: 23 additions & 3 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_api_version_after_question_mark(
)
endpoint, version = utils.parse_azure_endpoint("TEST_ENDPOINT")
assert version == "2025-01-01-preview"
assert endpoint == "https://myhost.openai.azure.com/openai/deployments/gpt-4"
assert endpoint == "https://myhost.openai.azure.com"

def test_api_version_after_ampersand(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""api-version preceded by & (not the first query parameter)."""
Expand Down Expand Up @@ -146,13 +146,13 @@ def test_query_string_stripped_from_endpoint(
def test_query_string_stripped_with_path(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Query string stripped even when endpoint includes a path."""
"""Query string and deployment path stripped from endpoint."""
monkeypatch.setenv(
"TEST_ENDPOINT",
"https://myhost.openai.azure.com/openai/deployments/gpt-4?api-version=2025-01-01-preview",
)
endpoint, version = utils.parse_azure_endpoint("TEST_ENDPOINT")
assert endpoint == "https://myhost.openai.azure.com/openai/deployments/gpt-4"
assert endpoint == "https://myhost.openai.azure.com"
assert "?" not in endpoint
assert version == "2025-01-01-preview"

Expand All @@ -169,6 +169,26 @@ def test_query_string_stripped_multiple_params(
assert "foo" not in endpoint
assert version == "2024-06-01"

def test_bare_openai_path_stripped(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Trailing /openai without /deployments/ is stripped."""
monkeypatch.setenv(
"TEST_ENDPOINT",
"https://myhost.openai.azure.com/openai?api-version=2024-06-01",
)
endpoint, version = utils.parse_azure_endpoint("TEST_ENDPOINT")
assert endpoint == "https://myhost.openai.azure.com"
assert version == "2024-06-01"

def test_apim_prefix_preserved(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""APIM prefix before /openai/deployments/ is kept."""
monkeypatch.setenv(
"TEST_ENDPOINT",
"https://apim.net/openai/openai/deployments/gpt-4o/chat/completions?api-version=2025-01-01-preview",
)
endpoint, version = utils.parse_azure_endpoint("TEST_ENDPOINT")
assert endpoint == "https://apim.net/openai"
assert version == "2025-01-01-preview"

def test_no_api_version_raises(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""RuntimeError when the endpoint has no api-version field."""
monkeypatch.setenv(
Expand Down
Loading
Loading