diff --git a/intercept/chatcompletions/base.go b/intercept/chatcompletions/base.go index b71cf90f..e4a6b056 100644 --- a/intercept/chatcompletions/base.go +++ b/intercept/chatcompletions/base.go @@ -26,9 +26,10 @@ import ( ) type interceptionBase struct { - id uuid.UUID - req *ChatCompletionNewParamsWrapper - cfg config.OpenAI + id uuid.UUID + providerName string + req *ChatCompletionNewParamsWrapper + cfg config.OpenAI // clientHeaders are the original HTTP headers from the client request. clientHeaders http.Header @@ -62,7 +63,7 @@ func (i *interceptionBase) newCompletionsService() openai.ChatCompletionService } // Add API dump middleware if configured - if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, config.ProviderOpenAI, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { + if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, i.providerName, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { opts = append(opts, option.WithMiddleware(mw)) } @@ -97,7 +98,7 @@ func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) attribute.String(tracing.RequestPath, r.URL.Path), attribute.String(tracing.InterceptionID, s.id.String()), attribute.String(tracing.InitiatorID, aibcontext.ActorIDFromContext(r.Context())), - attribute.String(tracing.Provider, config.ProviderOpenAI), + attribute.String(tracing.Provider, s.providerName), attribute.String(tracing.Model, s.Model()), attribute.Bool(tracing.Streaming, streaming), } diff --git a/intercept/chatcompletions/blocking.go b/intercept/chatcompletions/blocking.go index 2816ed7a..9e398d05 100644 --- a/intercept/chatcompletions/blocking.go +++ b/intercept/chatcompletions/blocking.go @@ -31,6 +31,7 @@ type BlockingInterception struct { func NewBlockingInterceptor( id uuid.UUID, req *ChatCompletionNewParamsWrapper, + providerName string, cfg config.OpenAI, clientHeaders http.Header, authHeaderName string, @@ -38,6 +39,7 @@ func NewBlockingInterceptor( ) *BlockingInterception { return &BlockingInterception{interceptionBase: interceptionBase{ id: id, + providerName: providerName, req: req, cfg: cfg, clientHeaders: clientHeaders, diff --git a/intercept/chatcompletions/streaming.go b/intercept/chatcompletions/streaming.go index b550c8e6..1d705e36 100644 --- a/intercept/chatcompletions/streaming.go +++ b/intercept/chatcompletions/streaming.go @@ -36,6 +36,7 @@ type StreamingInterception struct { func NewStreamingInterceptor( id uuid.UUID, req *ChatCompletionNewParamsWrapper, + providerName string, cfg config.OpenAI, clientHeaders http.Header, authHeaderName string, @@ -43,6 +44,7 @@ func NewStreamingInterceptor( ) *StreamingInterception { return &StreamingInterception{interceptionBase: interceptionBase{ id: id, + providerName: providerName, req: req, cfg: cfg, clientHeaders: clientHeaders, diff --git a/intercept/chatcompletions/streaming_test.go b/intercept/chatcompletions/streaming_test.go index 233831e6..54c47336 100644 --- a/intercept/chatcompletions/streaming_test.go +++ b/intercept/chatcompletions/streaming_test.go @@ -86,7 +86,7 @@ func TestStreamingInterception_RelaysUpstreamErrorToClient(t *testing.T) { httpReq := httptest.NewRequest(http.MethodPost, "/chat/completions", nil) tracer := otel.Tracer("test") - interceptor := NewStreamingInterceptor(uuid.New(), req, cfg, httpReq.Header, "Authorization", tracer) + interceptor := NewStreamingInterceptor(uuid.New(), req, config.ProviderOpenAI, cfg, httpReq.Header, "Authorization", tracer) logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) interceptor.Setup(logger, &testutil.MockRecorder{}, nil) diff --git a/intercept/messages/base.go b/intercept/messages/base.go index af7d5915..ccbd91ba 100644 --- a/intercept/messages/base.go +++ b/intercept/messages/base.go @@ -63,8 +63,9 @@ var bedrockSupportedBetaFlags = map[string]bool{ } type interceptionBase struct { - id uuid.UUID - reqPayload MessagesRequestPayload + id uuid.UUID + providerName string + reqPayload MessagesRequestPayload cfg aibconfig.Anthropic bedrockCfg *aibconfig.AWSBedrock @@ -115,7 +116,7 @@ func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) attribute.String(tracing.RequestPath, r.URL.Path), attribute.String(tracing.InterceptionID, s.id.String()), attribute.String(tracing.InitiatorID, aibcontext.ActorIDFromContext(r.Context())), - attribute.String(tracing.Provider, aibconfig.ProviderAnthropic), + attribute.String(tracing.Provider, s.providerName), attribute.String(tracing.Model, s.Model()), attribute.Bool(tracing.Streaming, streaming), attribute.Bool(tracing.IsBedrock, s.bedrockCfg != nil), @@ -232,7 +233,7 @@ func (i *interceptionBase) newMessagesService(ctx context.Context, opts ...optio } // Add API dump middleware if configured - if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, aibconfig.ProviderAnthropic, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { + if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, i.providerName, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { opts = append(opts, option.WithMiddleware(mw)) } diff --git a/intercept/messages/blocking.go b/intercept/messages/blocking.go index e3c91303..19bae614 100644 --- a/intercept/messages/blocking.go +++ b/intercept/messages/blocking.go @@ -31,6 +31,7 @@ type BlockingInterception struct { func NewBlockingInterceptor( id uuid.UUID, reqPayload MessagesRequestPayload, + providerName string, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, clientHeaders http.Header, @@ -39,6 +40,7 @@ func NewBlockingInterceptor( ) *BlockingInterception { return &BlockingInterception{interceptionBase: interceptionBase{ id: id, + providerName: providerName, reqPayload: reqPayload, cfg: cfg, bedrockCfg: bedrockCfg, diff --git a/intercept/messages/streaming.go b/intercept/messages/streaming.go index 0303677e..69f423d0 100644 --- a/intercept/messages/streaming.go +++ b/intercept/messages/streaming.go @@ -37,6 +37,7 @@ type StreamingInterception struct { func NewStreamingInterceptor( id uuid.UUID, reqPayload MessagesRequestPayload, + providerName string, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, clientHeaders http.Header, @@ -45,6 +46,7 @@ func NewStreamingInterceptor( ) *StreamingInterception { return &StreamingInterception{interceptionBase: interceptionBase{ id: id, + providerName: providerName, reqPayload: reqPayload, cfg: cfg, bedrockCfg: bedrockCfg, diff --git a/intercept/responses/base.go b/intercept/responses/base.go index e127d1fb..32de1a08 100644 --- a/intercept/responses/base.go +++ b/intercept/responses/base.go @@ -36,7 +36,8 @@ const ( ) type responsesInterceptionBase struct { - id uuid.UUID + id uuid.UUID + providerName string // clientHeaders are the original HTTP headers from the client request. clientHeaders http.Header authHeaderName string @@ -71,7 +72,7 @@ func (i *responsesInterceptionBase) newResponsesService() responses.ResponseServ } // Add API dump middleware if configured - if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, config.ProviderOpenAI, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { + if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, i.providerName, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { opts = append(opts, option.WithMiddleware(mw)) } @@ -101,7 +102,7 @@ func (i *responsesInterceptionBase) baseTraceAttributes(r *http.Request, streami attribute.String(tracing.RequestPath, r.URL.Path), attribute.String(tracing.InterceptionID, i.id.String()), attribute.String(tracing.InitiatorID, aibcontext.ActorIDFromContext(r.Context())), - attribute.String(tracing.Provider, config.ProviderOpenAI), + attribute.String(tracing.Provider, i.providerName), attribute.String(tracing.Model, i.Model()), attribute.Bool(tracing.Streaming, streaming), } diff --git a/intercept/responses/blocking.go b/intercept/responses/blocking.go index 62944310..d64adf9f 100644 --- a/intercept/responses/blocking.go +++ b/intercept/responses/blocking.go @@ -28,6 +28,7 @@ type BlockingResponsesInterceptor struct { func NewBlockingInterceptor( id uuid.UUID, reqPayload ResponsesRequestPayload, + providerName string, cfg config.OpenAI, clientHeaders http.Header, authHeaderName string, @@ -36,6 +37,7 @@ func NewBlockingInterceptor( return &BlockingResponsesInterceptor{ responsesInterceptionBase: responsesInterceptionBase{ id: id, + providerName: providerName, reqPayload: reqPayload, cfg: cfg, clientHeaders: clientHeaders, diff --git a/intercept/responses/streaming.go b/intercept/responses/streaming.go index dbd50673..359f82e8 100644 --- a/intercept/responses/streaming.go +++ b/intercept/responses/streaming.go @@ -35,6 +35,7 @@ type StreamingResponsesInterceptor struct { func NewStreamingInterceptor( id uuid.UUID, reqPayload ResponsesRequestPayload, + providerName string, cfg config.OpenAI, clientHeaders http.Header, authHeaderName string, @@ -43,6 +44,7 @@ func NewStreamingInterceptor( return &StreamingResponsesInterceptor{ responsesInterceptionBase: responsesInterceptionBase{ id: id, + providerName: providerName, reqPayload: reqPayload, cfg: cfg, clientHeaders: clientHeaders, diff --git a/internal/integrationtest/apidump_test.go b/internal/integrationtest/apidump_test.go index 77a4ea16..55231b05 100644 --- a/internal/integrationtest/apidump_test.go +++ b/internal/integrationtest/apidump_test.go @@ -25,10 +25,12 @@ func TestAPIDump(t *testing.T) { t.Parallel() cases := []struct { - name string - fixture []byte - providerFunc func(addr, dumpDir string) aibridge.Provider - path string + name string + fixture []byte + providerFunc func(addr, dumpDir string) aibridge.Provider + path string + headers http.Header + expectProviderDir string }{ { name: "anthropic", @@ -36,7 +38,8 @@ func TestAPIDump(t *testing.T) { providerFunc: func(addr, dumpDir string) aibridge.Provider { return provider.NewAnthropic(anthropicCfgWithAPIDump(addr, apiKey, dumpDir), nil) }, - path: pathAnthropicMessages, + path: pathAnthropicMessages, + expectProviderDir: config.ProviderAnthropic, }, { name: "openai_chat_completions", @@ -44,7 +47,8 @@ func TestAPIDump(t *testing.T) { providerFunc: func(addr, dumpDir string) aibridge.Provider { return provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir)) }, - path: pathOpenAIChatCompletions, + path: pathOpenAIChatCompletions, + expectProviderDir: config.ProviderOpenAI, }, { name: "openai_responses", @@ -52,7 +56,56 @@ func TestAPIDump(t *testing.T) { providerFunc: func(addr, dumpDir string) aibridge.Provider { return provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir)) }, - path: pathOpenAIResponses, + path: pathOpenAIResponses, + expectProviderDir: config.ProviderOpenAI, + }, + { + name: "copilot_chat_completions", + fixture: fixtures.OaiChatSimple, + providerFunc: func(addr, dumpDir string) aibridge.Provider { + return provider.NewCopilot(config.Copilot{BaseURL: addr, APIDumpDir: dumpDir}) + }, + path: pathCopilotChatCompletions, + headers: http.Header{"Authorization": {"Bearer test-copilot-token"}}, + expectProviderDir: config.ProviderCopilot, + }, + { + name: "copilot_responses", + fixture: fixtures.OaiResponsesBlockingSimple, + providerFunc: func(addr, dumpDir string) aibridge.Provider { + return provider.NewCopilot(config.Copilot{BaseURL: addr, APIDumpDir: dumpDir}) + }, + path: pathCopilotResponses, + headers: http.Header{"Authorization": {"Bearer test-copilot-token"}}, + expectProviderDir: config.ProviderCopilot, + }, + { + name: "copilot_custom_name_chat_completions", + fixture: fixtures.OaiChatSimple, + providerFunc: func(addr, dumpDir string) aibridge.Provider { + return provider.NewCopilot(config.Copilot{ + Name: "copilot-business", + BaseURL: addr, + APIDumpDir: dumpDir, + }) + }, + path: "/copilot-business/chat/completions", + headers: http.Header{"Authorization": {"Bearer test-copilot-token"}}, + expectProviderDir: "copilot-business", + }, + { + name: "copilot_custom_name_responses", + fixture: fixtures.OaiChatSimple, + providerFunc: func(addr, dumpDir string) aibridge.Provider { + return provider.NewCopilot(config.Copilot{ + Name: "copilot-enterprise", + BaseURL: addr, + APIDumpDir: dumpDir, + }) + }, + path: "/copilot-enterprise/chat/completions", + headers: http.Header{"Authorization": {"Bearer test-copilot-token"}}, + expectProviderDir: "copilot-enterprise", }, } @@ -74,7 +127,7 @@ func TestAPIDump(t *testing.T) { withCustomProvider(tc.providerFunc(srv.URL, dumpDir)), ) - resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request()) + resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request(), tc.headers) require.Equal(t, http.StatusOK, resp.StatusCode) _, err := io.ReadAll(resp.Body) require.NoError(t, err) @@ -107,6 +160,12 @@ func TestAPIDump(t *testing.T) { require.NotEmpty(t, reqDumpFile, "request dump file should exist") require.NotEmpty(t, respDumpFile, "response dump file should exist") + // Verify dump files are in the correct provider subdirectory. + require.Contains(t, reqDumpFile, filepath.Join(dumpDir, tc.expectProviderDir)+"/", + "request dump should be in the %s provider directory", tc.expectProviderDir) + require.Contains(t, respDumpFile, filepath.Join(dumpDir, tc.expectProviderDir)+"/", + "response dump should be in the %s provider directory", tc.expectProviderDir) + // Verify request dump contains expected HTTP request format. reqDumpData, err := os.ReadFile(reqDumpFile) require.NoError(t, err) diff --git a/internal/integrationtest/setupbridge.go b/internal/integrationtest/setupbridge.go index 1d061c26..bb999d21 100644 --- a/internal/integrationtest/setupbridge.go +++ b/internal/integrationtest/setupbridge.go @@ -26,9 +26,11 @@ import ( ) const ( - pathAnthropicMessages = "/anthropic/v1/messages" - pathOpenAIChatCompletions = "/openai/v1/chat/completions" - pathOpenAIResponses = "/openai/v1/responses" + pathAnthropicMessages = "/anthropic/v1/messages" + pathOpenAIChatCompletions = "/openai/v1/chat/completions" + pathOpenAIResponses = "/openai/v1/responses" + pathCopilotChatCompletions = "/copilot/chat/completions" + pathCopilotResponses = "/copilot/responses" // providerBedrock identifies a Bedrock provider in [withProvider]. // other providers use config.Provider* constants. diff --git a/provider/anthropic.go b/provider/anthropic.go index c3a1235e..2800220c 100644 --- a/provider/anthropic.go +++ b/provider/anthropic.go @@ -142,9 +142,9 @@ func (p *Anthropic) CreateInterceptor(w http.ResponseWriter, r *http.Request, tr var interceptor intercept.Interceptor if reqPayload.Stream() { - interceptor = messages.NewStreamingInterceptor(id, reqPayload, cfg, p.bedrockCfg, r.Header, authHeaderName, tracer) + interceptor = messages.NewStreamingInterceptor(id, reqPayload, p.Name(), cfg, p.bedrockCfg, r.Header, authHeaderName, tracer) } else { - interceptor = messages.NewBlockingInterceptor(id, reqPayload, cfg, p.bedrockCfg, r.Header, authHeaderName, tracer) + interceptor = messages.NewBlockingInterceptor(id, reqPayload, p.Name(), cfg, p.bedrockCfg, r.Header, authHeaderName, tracer) } span.SetAttributes(interceptor.TraceAttributes(r)...) return interceptor, nil diff --git a/provider/copilot.go b/provider/copilot.go index 7f60a6b5..eeeb74d4 100644 --- a/provider/copilot.go +++ b/provider/copilot.go @@ -156,9 +156,9 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac } if req.Stream { - interceptor = chatcompletions.NewStreamingInterceptor(id, &req, cfg, r.Header, p.AuthHeader(), tracer) + interceptor = chatcompletions.NewStreamingInterceptor(id, &req, p.Name(), cfg, r.Header, p.AuthHeader(), tracer) } else { - interceptor = chatcompletions.NewBlockingInterceptor(id, &req, cfg, r.Header, p.AuthHeader(), tracer) + interceptor = chatcompletions.NewBlockingInterceptor(id, &req, p.Name(), cfg, r.Header, p.AuthHeader(), tracer) } case routeCopilotResponses: @@ -172,9 +172,9 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac } if reqPayload.Stream() { - interceptor = responses.NewStreamingInterceptor(id, reqPayload, cfg, r.Header, p.AuthHeader(), tracer) + interceptor = responses.NewStreamingInterceptor(id, reqPayload, p.Name(), cfg, r.Header, p.AuthHeader(), tracer) } else { - interceptor = responses.NewBlockingInterceptor(id, reqPayload, cfg, r.Header, p.AuthHeader(), tracer) + interceptor = responses.NewBlockingInterceptor(id, reqPayload, p.Name(), cfg, r.Header, p.AuthHeader(), tracer) } default: diff --git a/provider/openai.go b/provider/openai.go index b794e85e..1cd85912 100644 --- a/provider/openai.go +++ b/provider/openai.go @@ -126,9 +126,9 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace } if req.Stream { - interceptor = chatcompletions.NewStreamingInterceptor(id, &req, cfg, r.Header, p.AuthHeader(), tracer) + interceptor = chatcompletions.NewStreamingInterceptor(id, &req, p.Name(), cfg, r.Header, p.AuthHeader(), tracer) } else { - interceptor = chatcompletions.NewBlockingInterceptor(id, &req, cfg, r.Header, p.AuthHeader(), tracer) + interceptor = chatcompletions.NewBlockingInterceptor(id, &req, p.Name(), cfg, r.Header, p.AuthHeader(), tracer) } case routeResponses: @@ -141,9 +141,9 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace return nil, fmt.Errorf("unmarshal request body: %w", err) } if reqPayload.Stream() { - interceptor = responses.NewStreamingInterceptor(id, reqPayload, cfg, r.Header, p.AuthHeader(), tracer) + interceptor = responses.NewStreamingInterceptor(id, reqPayload, p.Name(), cfg, r.Header, p.AuthHeader(), tracer) } else { - interceptor = responses.NewBlockingInterceptor(id, reqPayload, cfg, r.Header, p.AuthHeader(), tracer) + interceptor = responses.NewBlockingInterceptor(id, reqPayload, p.Name(), cfg, r.Header, p.AuthHeader(), tracer) } default: