From c100f386e96c0651c83624786379c739b2f0ad5c Mon Sep 17 00:00:00 2001 From: Susana Cardoso Ferreira Date: Mon, 30 Mar 2026 11:08:24 +0000 Subject: [PATCH] feat: add dynamic upstream resolution for Copilot provider --- api.go | 4 ++ bridge_test.go | 4 +- config/config.go | 8 +-- context/context.go | 15 +++++- internal/integrationtest/apidump_test.go | 2 +- provider/copilot.go | 41 ++++++++++++--- provider/copilot_test.go | 65 +++++++++++++++++++++++- 7 files changed, 124 insertions(+), 15 deletions(-) diff --git a/api.go b/api.go index e0486d77..ccb13564 100644 --- a/api.go +++ b/api.go @@ -44,6 +44,10 @@ func AsActor(ctx context.Context, actorID string, metadata recorder.Metadata) co return aibcontext.AsActor(ctx, actorID, metadata) } +func WithOriginalHost(ctx context.Context, host string) context.Context { + return aibcontext.WithOriginalHost(ctx, host) +} + func NewAnthropicProvider(cfg config.Anthropic, bedrockCfg *config.AWSBedrock) provider.Provider { return provider.NewAnthropic(cfg, bedrockCfg) } diff --git a/bridge_test.go b/bridge_test.go index 1709be17..533f8585 100644 --- a/bridge_test.go +++ b/bridge_test.go @@ -62,7 +62,7 @@ func TestPassthroughRoutesForProviders(t *testing.T) { name: "copilot_no_base_path", requestPath: "/copilot/models", provider: func(baseURL string) provider.Provider { - return NewCopilotProvider(config.Copilot{BaseURL: baseURL}) + return NewCopilotProvider(config.Copilot{DefaultUpstreamURL: baseURL}) }, expectPath: "/models", }, @@ -71,7 +71,7 @@ func TestPassthroughRoutesForProviders(t *testing.T) { baseURLPath: "/v1", requestPath: "/copilot/models", provider: func(baseURL string) provider.Provider { - return NewCopilotProvider(config.Copilot{BaseURL: baseURL}) + return NewCopilotProvider(config.Copilot{DefaultUpstreamURL: baseURL}) }, expectPath: "/v1/models", }, diff --git a/config/config.go b/config/config.go index 76c1c0a7..c7476873 100644 --- a/config/config.go +++ b/config/config.go @@ -77,7 +77,9 @@ func DefaultCircuitBreaker() CircuitBreaker { } type Copilot struct { - BaseURL string - APIDumpDir string - CircuitBreaker *CircuitBreaker + // DefaultUpstreamURL is the fallback upstream URL when no upstream + // header is provided in the request. + DefaultUpstreamURL string + APIDumpDir string + CircuitBreaker *CircuitBreaker } diff --git a/context/context.go b/context/context.go index ade88978..1ce9867b 100644 --- a/context/context.go +++ b/context/context.go @@ -7,7 +7,8 @@ import ( ) type ( - actorContextKey struct{} + actorContextKey struct{} + originalHostContextKey struct{} ) type Actor struct { @@ -28,6 +29,18 @@ func ActorFromContext(ctx context.Context) *Actor { return a } +// WithOriginalHost stores the original destination host in the context. +func WithOriginalHost(ctx context.Context, host string) context.Context { + return context.WithValue(ctx, originalHostContextKey{}, host) +} + +// OriginalHostFromContext retrieves the original destination host from the context. +// Returns an empty string if not set. +func OriginalHostFromContext(ctx context.Context) string { + h, _ := ctx.Value(originalHostContextKey{}).(string) + return h +} + // ActorIDFromContext safely extracts the actor ID from the context. // Returns an empty string if no actor is found. func ActorIDFromContext(ctx context.Context) string { diff --git a/internal/integrationtest/apidump_test.go b/internal/integrationtest/apidump_test.go index 77a4ea16..dd4d9183 100644 --- a/internal/integrationtest/apidump_test.go +++ b/internal/integrationtest/apidump_test.go @@ -170,7 +170,7 @@ func TestAPIDumpPassthrough(t *testing.T) { { name: "copilot", providerFunc: func(addr string, dumpDir string) aibridge.Provider { - return provider.NewCopilot(config.Copilot{BaseURL: addr, APIDumpDir: dumpDir}) + return provider.NewCopilot(config.Copilot{DefaultUpstreamURL: addr, APIDumpDir: dumpDir}) }, requestPath: "/copilot/models", expectDumpName: "-models-", diff --git a/provider/copilot.go b/provider/copilot.go index e414311d..58d6fa60 100644 --- a/provider/copilot.go +++ b/provider/copilot.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/coder/aibridge/config" + aibcontext "github.com/coder/aibridge/context" "github.com/coder/aibridge/intercept" "github.com/coder/aibridge/intercept/chatcompletions" "github.com/coder/aibridge/intercept/responses" @@ -20,13 +21,22 @@ import ( ) const ( - copilotBaseURL = "https://api.individual.githubcopilot.com" + copilotIndividualUpstreamURL = "https://api.individual.githubcopilot.com" + copilotBusinessUpstreamURL = "https://api.business.githubcopilot.com" + copilotEnterpriseUpstreamURL = "https://api.enterprise.githubcopilot.com" // Copilot exposes an OpenAI-compatible API, including for Anthropic models. routeCopilotChatCompletions = "/chat/completions" routeCopilotResponses = "/responses" ) +// copilotUpstreams maps upstream URLs to their names. +var copilotUpstreams = map[string]string{ + copilotIndividualUpstreamURL: "individual", + copilotBusinessUpstreamURL: "business", + copilotEnterpriseUpstreamURL: "enterprise", +} + var copilotOpenErrorResponse = func() []byte { return []byte(`{"error":{"message":"circuit breaker is open","type":"server_error","code":"service_unavailable"}}`) } @@ -52,8 +62,8 @@ type Copilot struct { var _ Provider = &Copilot{} func NewCopilot(cfg config.Copilot) *Copilot { - if cfg.BaseURL == "" { - cfg.BaseURL = copilotBaseURL + if cfg.DefaultUpstreamURL == "" { + cfg.DefaultUpstreamURL = copilotIndividualUpstreamURL } if cfg.APIDumpDir == "" { cfg.APIDumpDir = os.Getenv("BRIDGE_DUMP_DIR") @@ -72,11 +82,25 @@ func (p *Copilot) Name() string { } func (p *Copilot) BaseURL() string { - return p.cfg.BaseURL + return p.cfg.DefaultUpstreamURL } -func (p *Copilot) ResolveUpstream(_ *http.Request) intercept.ResolvedUpstream { - return intercept.ResolvedUpstream{Name: p.Name(), URL: p.cfg.BaseURL} +// ResolveUpstream determines the Copilot upstream based on the original +// destination host stored in the request context by coder. The host is +// mapped to a known upstream URL and name. +// If the host is absent or unknown, it falls back to the configured +// default upstream URL. +func (p *Copilot) ResolveUpstream(r *http.Request) intercept.ResolvedUpstream { + if host := aibcontext.OriginalHostFromContext(r.Context()); host != "" { + upstreamURL := "https://" + host + if name, ok := copilotUpstreams[upstreamURL]; ok { + return intercept.ResolvedUpstream{ + Name: config.ProviderCopilot + "-" + name, + URL: upstreamURL, + } + } + } + return intercept.ResolvedUpstream{Name: p.Name(), URL: p.cfg.DefaultUpstreamURL} } func (p *Copilot) RoutePrefix() string { @@ -119,6 +143,11 @@ func (p *Copilot) APIDumpDir() string { } func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, tracer trace.Tracer) (_ intercept.Interceptor, outErr error) { + fmt.Println("################### aibridge copilot CreateInterceptor headers received:") + for k, v := range r.Header { + fmt.Printf(" %s: %s\n", k, v) + } + _, span := tracer.Start(r.Context(), "Intercept.CreateInterceptor") defer tracing.EndSpanErr(span, &outErr) diff --git a/provider/copilot_test.go b/provider/copilot_test.go index d34a4208..907d779b 100644 --- a/provider/copilot_test.go +++ b/provider/copilot_test.go @@ -12,6 +12,7 @@ import ( "go.opentelemetry.io/otel" "github.com/coder/aibridge/config" + aibcontext "github.com/coder/aibridge/context" "github.com/coder/aibridge/internal/testutil" ) @@ -145,7 +146,7 @@ func TestCopilot_CreateInterceptor(t *testing.T) { // Create provider with mock upstream URL provider := NewCopilot(config.Copilot{ - BaseURL: mockUpstream.URL, + DefaultUpstreamURL: mockUpstream.URL, }) body := `{"model": "gpt-4", "messages": [{"role": "user", "content": "hello"}], "stream": false}` @@ -236,7 +237,7 @@ func TestCopilot_CreateInterceptor(t *testing.T) { // Create provider with mock upstream URL provider := NewCopilot(config.Copilot{ - BaseURL: mockUpstream.URL, + DefaultUpstreamURL: mockUpstream.URL, }) body := `{"model": "gpt-5-mini", "input": "hello", "stream": false}` @@ -325,3 +326,63 @@ func TestExtractCopilotHeaders(t *testing.T) { }) } } + +func TestCopilot_ResolveUpstream(t *testing.T) { + t.Parallel() + + provider := NewCopilot(config.Copilot{}) + + tests := []struct { + name string + host string + expectName string + expectURL string + }{ + { + name: "no_header_returns_default", + host: "", + expectName: config.ProviderCopilot, + expectURL: copilotIndividualUpstreamURL, + }, + { + name: "individual", + host: "api.individual.githubcopilot.com", + expectName: "copilot-individual", + expectURL: copilotIndividualUpstreamURL, + }, + { + name: "business", + host: "api.business.githubcopilot.com", + expectName: "copilot-business", + expectURL: copilotBusinessUpstreamURL, + }, + { + name: "enterprise", + host: "api.enterprise.githubcopilot.com", + expectName: "copilot-enterprise", + expectURL: copilotEnterpriseUpstreamURL, + }, + { + name: "unknown_host_returns_default", + host: "unknown.example.com", + expectName: config.ProviderCopilot, + expectURL: copilotIndividualUpstreamURL, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest(http.MethodPost, "/", nil) + if tc.host != "" { + ctx := aibcontext.WithOriginalHost(req.Context(), tc.host) + req = req.WithContext(ctx) + } + + upstream := provider.ResolveUpstream(req) + assert.Equal(t, tc.expectName, upstream.Name) + assert.Equal(t, tc.expectURL, upstream.URL) + }) + } +}