Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
try:
from copilot import CopilotClient, CopilotSession, SubprocessConfig
from copilot.generated.session_events import PermissionRequest, SessionEvent, SessionEventType
from copilot.session import MCPServerConfig, PermissionRequestResult, SystemMessageConfig
from copilot.session import MCPServerConfig, PermissionRequestResult, ProviderConfig, SystemMessageConfig
from copilot.tools import Tool as CopilotTool
from copilot.tools import ToolInvocation, ToolResult
except ImportError as _copilot_import_error:
Expand Down Expand Up @@ -120,6 +120,12 @@ class GitHubCopilotOptions(TypedDict, total=False):
Supports both local (stdio) and remote (HTTP/SSE) servers.
"""

provider: ProviderConfig
"""Custom API provider configuration for BYOK (Bring Your Own Key) scenarios.
Allows routing requests through your own OpenAI, Azure, or Anthropic endpoint
instead of the default GitHub Copilot backend.
"""


OptionsT = TypeVar(
"OptionsT",
Expand Down Expand Up @@ -232,6 +238,7 @@ def __init__(
log_level = opts.pop("log_level", None)
on_permission_request: PermissionHandlerType | None = opts.pop("on_permission_request", None)
mcp_servers: dict[str, MCPServerConfig] | None = opts.pop("mcp_servers", None)
provider: ProviderConfig | None = opts.pop("provider", None)

self._settings = load_settings(
GitHubCopilotSettings,
Expand All @@ -247,6 +254,7 @@ def __init__(
self._tools = normalize_tools(tools)
self._permission_handler = on_permission_request
self._mcp_servers = mcp_servers
self._provider = provider
self._default_options = opts
self._started = False

Expand Down Expand Up @@ -730,6 +738,7 @@ async def _create_session(
opts.get("on_permission_request") or self._permission_handler or _deny_all_permissions
)
mcp_servers = opts.get("mcp_servers") or self._mcp_servers or None
provider = opts.get("provider") or self._provider or None
tools = self._prepare_tools(self._tools) if self._tools else None

return await self._client.create_session(
Expand All @@ -739,6 +748,7 @@ async def _create_session(
system_message=system_message or None,
tools=tools or None,
mcp_servers=mcp_servers or None,
provider=provider or None,
)

async def _resume_session(self, session_id: str, streaming: bool) -> CopilotSession:
Expand All @@ -755,4 +765,5 @@ async def _resume_session(self, session_id: str, streaming: bool) -> CopilotSess
streaming=streaming,
tools=tools or None,
mcp_servers=self._mcp_servers or None,
provider=self._provider or None,
)
193 changes: 193 additions & 0 deletions python/packages/github_copilot/tests/test_github_copilot_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,6 +861,7 @@ async def test_session_resumed_for_same_session(
streaming=unittest.mock.ANY,
tools=unittest.mock.ANY,
mcp_servers=unittest.mock.ANY,
provider=unittest.mock.ANY,
)

async def test_session_config_includes_model(
Expand Down Expand Up @@ -1084,6 +1085,198 @@ async def test_session_config_excludes_mcp_servers_when_not_set(
assert config["mcp_servers"] is None


class TestGitHubCopilotAgentProvider:
"""Test cases for provider configuration (BYOK / Managed Identity)."""

async def test_provider_passed_to_create_session(
self,
mock_client: MagicMock,
) -> None:
"""Test that provider config is passed through to create_session."""
from copilot.session import ProviderConfig

provider: ProviderConfig = {
"type": "azure",
"base_url": "https://my-resource.openai.azure.com",
"bearer_token": "test-token",
}

agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent(
client=mock_client,
default_options={"provider": provider},
)
await agent.start()

await agent._get_or_create_session(AgentSession()) # type: ignore

call_args = mock_client.create_session.call_args
config = call_args.kwargs
assert config["provider"]["type"] == "azure"
assert config["provider"]["base_url"] == "https://my-resource.openai.azure.com"
assert config["provider"]["bearer_token"] == "test-token"

async def test_provider_passed_to_resume_session(
self,
mock_client: MagicMock,
) -> None:
"""Test that provider config is passed through to resume_session."""
from copilot.session import ProviderConfig

provider: ProviderConfig = {
"type": "azure",
"base_url": "https://my-resource.openai.azure.com",
"bearer_token": "test-token",
}

agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent(
client=mock_client,
default_options={"provider": provider},
)
await agent.start()

session = AgentSession()
session.service_session_id = "existing-session-id"

await agent._get_or_create_session(session) # type: ignore

mock_client.resume_session.assert_called_once()
call_args = mock_client.resume_session.call_args
config = call_args.kwargs
assert config["provider"]["type"] == "azure"

async def test_session_config_excludes_provider_when_not_set(
self,
mock_client: MagicMock,
) -> None:
"""Test that provider is None in session config when not set."""
agent = GitHubCopilotAgent(client=mock_client)
await agent.start()

await agent._get_or_create_session(AgentSession()) # type: ignore

call_args = mock_client.create_session.call_args
config = call_args.kwargs
assert config["provider"] is None

async def test_resume_session_excludes_provider_when_not_set(
self,
mock_client: MagicMock,
) -> None:
"""Test that provider is None in resume session config when not set."""
agent = GitHubCopilotAgent(client=mock_client)
await agent.start()

session = AgentSession()
session.service_session_id = "existing-session-id"

await agent._get_or_create_session(session) # type: ignore

call_args = mock_client.resume_session.call_args
config = call_args.kwargs
assert config["provider"] is None

async def test_runtime_provider_takes_precedence(
self,
mock_client: MagicMock,
) -> None:
"""Test that runtime provider options override default_options provider."""
from copilot.session import ProviderConfig

default_provider: ProviderConfig = {
"type": "azure",
"base_url": "https://default.openai.azure.com",
"bearer_token": "default-token",
}
runtime_provider: ProviderConfig = {
"type": "openai",
"base_url": "https://runtime.openai.com",
"api_key": "runtime-key",
}

agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent(
client=mock_client,
default_options={"provider": default_provider},
)
await agent.start()

await agent._get_or_create_session( # type: ignore
AgentSession(),
runtime_options={"provider": runtime_provider},
)

call_args = mock_client.create_session.call_args
config = call_args.kwargs
assert config["provider"]["type"] == "openai"
assert config["provider"]["base_url"] == "https://runtime.openai.com"

async def test_provider_not_leaked_into_default_options(
self,
mock_client: MagicMock,
) -> None:
"""Test that provider is popped from opts and not left in _default_options."""
from copilot.session import ProviderConfig

provider: ProviderConfig = {
"type": "azure",
"base_url": "https://my-resource.openai.azure.com",
"bearer_token": "test-token",
}

agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent(
client=mock_client,
default_options={"provider": provider, "model": "gpt-5"},
)

assert "provider" not in agent._default_options
assert agent._provider is not None
assert agent._provider["type"] == "azure"

async def test_provider_coexists_with_other_options(
self,
mock_client: MagicMock,
) -> None:
"""Test that provider works alongside model, tools, and mcp_servers."""
from copilot.session import MCPServerConfig, ProviderConfig

provider: ProviderConfig = {
"type": "azure",
"base_url": "https://my-resource.openai.azure.com",
"bearer_token": "test-token",
}
mcp_servers: dict[str, MCPServerConfig] = {
"test-server": {
"type": "stdio",
"command": "echo",
"args": ["hello"],
"tools": ["*"],
},
}

def my_tool(arg: str) -> str:
"""A test tool."""
return arg

agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent(
client=mock_client,
tools=[my_tool],
default_options={
"model": "gpt-5",
"provider": provider,
"mcp_servers": mcp_servers,
},
)
await agent.start()

await agent._get_or_create_session(AgentSession()) # type: ignore

call_args = mock_client.create_session.call_args
config = call_args.kwargs
assert config["provider"]["type"] == "azure"
assert config["model"] == "gpt-5"
assert config["mcp_servers"] is not None
assert config["tools"] is not None


class TestGitHubCopilotAgentToolConversion:
"""Test cases for tool conversion."""

Expand Down
Loading