Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions intercept/chatcompletions/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@ import (
)

type interceptionBase struct {
id uuid.UUID
req *ChatCompletionNewParamsWrapper
cfg config.OpenAI
id uuid.UUID
providerName string
req *ChatCompletionNewParamsWrapper
cfg config.OpenAI

// clientHeaders are the original HTTP headers from the client request.
clientHeaders http.Header
Expand Down Expand Up @@ -62,7 +63,7 @@ func (i *interceptionBase) newCompletionsService() openai.ChatCompletionService
}

// Add API dump middleware if configured
if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, config.ProviderOpenAI, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil {
if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, i.providerName, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil {
opts = append(opts, option.WithMiddleware(mw))
}

Expand Down Expand Up @@ -97,7 +98,7 @@ func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool)
attribute.String(tracing.RequestPath, r.URL.Path),
attribute.String(tracing.InterceptionID, s.id.String()),
attribute.String(tracing.InitiatorID, aibcontext.ActorIDFromContext(r.Context())),
attribute.String(tracing.Provider, config.ProviderOpenAI),
attribute.String(tracing.Provider, s.providerName),
attribute.String(tracing.Model, s.Model()),
attribute.Bool(tracing.Streaming, streaming),
}
Expand Down
2 changes: 2 additions & 0 deletions intercept/chatcompletions/blocking.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,15 @@ type BlockingInterception struct {
func NewBlockingInterceptor(
id uuid.UUID,
req *ChatCompletionNewParamsWrapper,
providerName string,
cfg config.OpenAI,
clientHeaders http.Header,
authHeaderName string,
tracer trace.Tracer,
) *BlockingInterception {
return &BlockingInterception{interceptionBase: interceptionBase{
id: id,
providerName: providerName,
req: req,
cfg: cfg,
clientHeaders: clientHeaders,
Expand Down
2 changes: 2 additions & 0 deletions intercept/chatcompletions/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,15 @@ type StreamingInterception struct {
func NewStreamingInterceptor(
id uuid.UUID,
req *ChatCompletionNewParamsWrapper,
providerName string,
cfg config.OpenAI,
clientHeaders http.Header,
authHeaderName string,
tracer trace.Tracer,
) *StreamingInterception {
return &StreamingInterception{interceptionBase: interceptionBase{
id: id,
providerName: providerName,
req: req,
cfg: cfg,
clientHeaders: clientHeaders,
Expand Down
2 changes: 1 addition & 1 deletion intercept/chatcompletions/streaming_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func TestStreamingInterception_RelaysUpstreamErrorToClient(t *testing.T) {
httpReq := httptest.NewRequest(http.MethodPost, "/chat/completions", nil)

tracer := otel.Tracer("test")
interceptor := NewStreamingInterceptor(uuid.New(), req, cfg, httpReq.Header, "Authorization", tracer)
interceptor := NewStreamingInterceptor(uuid.New(), req, config.ProviderOpenAI, cfg, httpReq.Header, "Authorization", tracer)

logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
interceptor.Setup(logger, &testutil.MockRecorder{}, nil)
Expand Down
9 changes: 5 additions & 4 deletions intercept/messages/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ var bedrockSupportedBetaFlags = map[string]bool{
}

type interceptionBase struct {
id uuid.UUID
reqPayload MessagesRequestPayload
id uuid.UUID
providerName string
reqPayload MessagesRequestPayload

cfg aibconfig.Anthropic
bedrockCfg *aibconfig.AWSBedrock
Expand Down Expand Up @@ -115,7 +116,7 @@ func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool)
attribute.String(tracing.RequestPath, r.URL.Path),
attribute.String(tracing.InterceptionID, s.id.String()),
attribute.String(tracing.InitiatorID, aibcontext.ActorIDFromContext(r.Context())),
attribute.String(tracing.Provider, aibconfig.ProviderAnthropic),
attribute.String(tracing.Provider, s.providerName),
attribute.String(tracing.Model, s.Model()),
attribute.Bool(tracing.Streaming, streaming),
attribute.Bool(tracing.IsBedrock, s.bedrockCfg != nil),
Expand Down Expand Up @@ -232,7 +233,7 @@ func (i *interceptionBase) newMessagesService(ctx context.Context, opts ...optio
}

// Add API dump middleware if configured
if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, aibconfig.ProviderAnthropic, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil {
if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, i.providerName, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil {
opts = append(opts, option.WithMiddleware(mw))
}

Expand Down
2 changes: 2 additions & 0 deletions intercept/messages/blocking.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ type BlockingInterception struct {
func NewBlockingInterceptor(
id uuid.UUID,
reqPayload MessagesRequestPayload,
providerName string,
cfg config.Anthropic,
bedrockCfg *config.AWSBedrock,
clientHeaders http.Header,
Expand All @@ -39,6 +40,7 @@ func NewBlockingInterceptor(
) *BlockingInterception {
return &BlockingInterception{interceptionBase: interceptionBase{
id: id,
providerName: providerName,
reqPayload: reqPayload,
cfg: cfg,
bedrockCfg: bedrockCfg,
Expand Down
2 changes: 2 additions & 0 deletions intercept/messages/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ type StreamingInterception struct {
func NewStreamingInterceptor(
id uuid.UUID,
reqPayload MessagesRequestPayload,
providerName string,
cfg config.Anthropic,
bedrockCfg *config.AWSBedrock,
clientHeaders http.Header,
Expand All @@ -45,6 +46,7 @@ func NewStreamingInterceptor(
) *StreamingInterception {
return &StreamingInterception{interceptionBase: interceptionBase{
id: id,
providerName: providerName,
reqPayload: reqPayload,
cfg: cfg,
bedrockCfg: bedrockCfg,
Expand Down
7 changes: 4 additions & 3 deletions intercept/responses/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ const (
)

type responsesInterceptionBase struct {
id uuid.UUID
id uuid.UUID
providerName string
// clientHeaders are the original HTTP headers from the client request.
clientHeaders http.Header
authHeaderName string
Expand Down Expand Up @@ -71,7 +72,7 @@ func (i *responsesInterceptionBase) newResponsesService() responses.ResponseServ
}

// Add API dump middleware if configured
if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, config.ProviderOpenAI, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil {
if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, i.providerName, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil {
opts = append(opts, option.WithMiddleware(mw))
}

Expand Down Expand Up @@ -101,7 +102,7 @@ func (i *responsesInterceptionBase) baseTraceAttributes(r *http.Request, streami
attribute.String(tracing.RequestPath, r.URL.Path),
attribute.String(tracing.InterceptionID, i.id.String()),
attribute.String(tracing.InitiatorID, aibcontext.ActorIDFromContext(r.Context())),
attribute.String(tracing.Provider, config.ProviderOpenAI),
attribute.String(tracing.Provider, i.providerName),
attribute.String(tracing.Model, i.Model()),
attribute.Bool(tracing.Streaming, streaming),
}
Expand Down
2 changes: 2 additions & 0 deletions intercept/responses/blocking.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type BlockingResponsesInterceptor struct {
func NewBlockingInterceptor(
id uuid.UUID,
reqPayload ResponsesRequestPayload,
providerName string,
cfg config.OpenAI,
clientHeaders http.Header,
authHeaderName string,
Expand All @@ -36,6 +37,7 @@ func NewBlockingInterceptor(
return &BlockingResponsesInterceptor{
responsesInterceptionBase: responsesInterceptionBase{
id: id,
providerName: providerName,
reqPayload: reqPayload,
cfg: cfg,
clientHeaders: clientHeaders,
Expand Down
2 changes: 2 additions & 0 deletions intercept/responses/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ type StreamingResponsesInterceptor struct {
func NewStreamingInterceptor(
id uuid.UUID,
reqPayload ResponsesRequestPayload,
providerName string,
cfg config.OpenAI,
clientHeaders http.Header,
authHeaderName string,
Expand All @@ -43,6 +44,7 @@ func NewStreamingInterceptor(
return &StreamingResponsesInterceptor{
responsesInterceptionBase: responsesInterceptionBase{
id: id,
providerName: providerName,
reqPayload: reqPayload,
cfg: cfg,
clientHeaders: clientHeaders,
Expand Down
75 changes: 67 additions & 8 deletions internal/integrationtest/apidump_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,34 +25,87 @@ func TestAPIDump(t *testing.T) {
t.Parallel()

cases := []struct {
name string
fixture []byte
providerFunc func(addr, dumpDir string) aibridge.Provider
path string
name string
fixture []byte
providerFunc func(addr, dumpDir string) aibridge.Provider
path string
headers http.Header
expectProviderDir string
}{
{
name: "anthropic",
fixture: fixtures.AntSimple,
providerFunc: func(addr, dumpDir string) aibridge.Provider {
return provider.NewAnthropic(anthropicCfgWithAPIDump(addr, apiKey, dumpDir), nil)
},
path: pathAnthropicMessages,
path: pathAnthropicMessages,
expectProviderDir: config.ProviderAnthropic,
},
{
name: "openai_chat_completions",
fixture: fixtures.OaiChatSimple,
providerFunc: func(addr, dumpDir string) aibridge.Provider {
return provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir))
},
path: pathOpenAIChatCompletions,
path: pathOpenAIChatCompletions,
expectProviderDir: config.ProviderOpenAI,
},
{
name: "openai_responses",
fixture: fixtures.OaiResponsesBlockingSimple,
providerFunc: func(addr, dumpDir string) aibridge.Provider {
return provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir))
},
path: pathOpenAIResponses,
path: pathOpenAIResponses,
expectProviderDir: config.ProviderOpenAI,
},
{
name: "copilot_chat_completions",
fixture: fixtures.OaiChatSimple,
providerFunc: func(addr, dumpDir string) aibridge.Provider {
return provider.NewCopilot(config.Copilot{BaseURL: addr, APIDumpDir: dumpDir})
},
path: pathCopilotChatCompletions,
headers: http.Header{"Authorization": {"Bearer test-copilot-token"}},
expectProviderDir: config.ProviderCopilot,
},
{
name: "copilot_responses",
fixture: fixtures.OaiResponsesBlockingSimple,
providerFunc: func(addr, dumpDir string) aibridge.Provider {
return provider.NewCopilot(config.Copilot{BaseURL: addr, APIDumpDir: dumpDir})
},
path: pathCopilotResponses,
headers: http.Header{"Authorization": {"Bearer test-copilot-token"}},
expectProviderDir: config.ProviderCopilot,
},
{
name: "copilot_custom_name_chat_completions",
fixture: fixtures.OaiChatSimple,
providerFunc: func(addr, dumpDir string) aibridge.Provider {
return provider.NewCopilot(config.Copilot{
Name: "copilot-business",
BaseURL: addr,
APIDumpDir: dumpDir,
})
},
path: "/copilot-business/chat/completions",
headers: http.Header{"Authorization": {"Bearer test-copilot-token"}},
expectProviderDir: "copilot-business",
},
{
name: "copilot_custom_name_responses",
fixture: fixtures.OaiChatSimple,
providerFunc: func(addr, dumpDir string) aibridge.Provider {
return provider.NewCopilot(config.Copilot{
Name: "copilot-enterprise",
BaseURL: addr,
APIDumpDir: dumpDir,
})
},
path: "/copilot-enterprise/chat/completions",
headers: http.Header{"Authorization": {"Bearer test-copilot-token"}},
expectProviderDir: "copilot-enterprise",
},
}

Expand All @@ -74,7 +127,7 @@ func TestAPIDump(t *testing.T) {
withCustomProvider(tc.providerFunc(srv.URL, dumpDir)),
)

resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request())
resp := bridgeServer.makeRequest(t, http.MethodPost, tc.path, fix.Request(), tc.headers)
require.Equal(t, http.StatusOK, resp.StatusCode)
_, err := io.ReadAll(resp.Body)
require.NoError(t, err)
Expand Down Expand Up @@ -107,6 +160,12 @@ func TestAPIDump(t *testing.T) {
require.NotEmpty(t, reqDumpFile, "request dump file should exist")
require.NotEmpty(t, respDumpFile, "response dump file should exist")

// Verify dump files are in the correct provider subdirectory.
require.Contains(t, reqDumpFile, filepath.Join(dumpDir, tc.expectProviderDir)+"/",
"request dump should be in the %s provider directory", tc.expectProviderDir)
require.Contains(t, respDumpFile, filepath.Join(dumpDir, tc.expectProviderDir)+"/",
"response dump should be in the %s provider directory", tc.expectProviderDir)

// Verify request dump contains expected HTTP request format.
reqDumpData, err := os.ReadFile(reqDumpFile)
require.NoError(t, err)
Expand Down
8 changes: 5 additions & 3 deletions internal/integrationtest/setupbridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ import (
)

const (
pathAnthropicMessages = "/anthropic/v1/messages"
pathOpenAIChatCompletions = "/openai/v1/chat/completions"
pathOpenAIResponses = "/openai/v1/responses"
pathAnthropicMessages = "/anthropic/v1/messages"
pathOpenAIChatCompletions = "/openai/v1/chat/completions"
pathOpenAIResponses = "/openai/v1/responses"
pathCopilotChatCompletions = "/copilot/chat/completions"
pathCopilotResponses = "/copilot/responses"

// providerBedrock identifies a Bedrock provider in [withProvider].
// other providers use config.Provider* constants.
Expand Down
4 changes: 2 additions & 2 deletions provider/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,9 @@ func (p *Anthropic) CreateInterceptor(w http.ResponseWriter, r *http.Request, tr

var interceptor intercept.Interceptor
if reqPayload.Stream() {
interceptor = messages.NewStreamingInterceptor(id, reqPayload, cfg, p.bedrockCfg, r.Header, authHeaderName, tracer)
interceptor = messages.NewStreamingInterceptor(id, reqPayload, p.Name(), cfg, p.bedrockCfg, r.Header, authHeaderName, tracer)
} else {
interceptor = messages.NewBlockingInterceptor(id, reqPayload, cfg, p.bedrockCfg, r.Header, authHeaderName, tracer)
interceptor = messages.NewBlockingInterceptor(id, reqPayload, p.Name(), cfg, p.bedrockCfg, r.Header, authHeaderName, tracer)
}
span.SetAttributes(interceptor.TraceAttributes(r)...)
return interceptor, nil
Expand Down
8 changes: 4 additions & 4 deletions provider/copilot.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,9 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac
}

if req.Stream {
interceptor = chatcompletions.NewStreamingInterceptor(id, &req, cfg, r.Header, p.AuthHeader(), tracer)
interceptor = chatcompletions.NewStreamingInterceptor(id, &req, p.Name(), cfg, r.Header, p.AuthHeader(), tracer)
} else {
interceptor = chatcompletions.NewBlockingInterceptor(id, &req, cfg, r.Header, p.AuthHeader(), tracer)
interceptor = chatcompletions.NewBlockingInterceptor(id, &req, p.Name(), cfg, r.Header, p.AuthHeader(), tracer)
}

case routeCopilotResponses:
Expand All @@ -172,9 +172,9 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac
}

if reqPayload.Stream() {
interceptor = responses.NewStreamingInterceptor(id, reqPayload, cfg, r.Header, p.AuthHeader(), tracer)
interceptor = responses.NewStreamingInterceptor(id, reqPayload, p.Name(), cfg, r.Header, p.AuthHeader(), tracer)
} else {
interceptor = responses.NewBlockingInterceptor(id, reqPayload, cfg, r.Header, p.AuthHeader(), tracer)
interceptor = responses.NewBlockingInterceptor(id, reqPayload, p.Name(), cfg, r.Header, p.AuthHeader(), tracer)
}

default:
Expand Down
8 changes: 4 additions & 4 deletions provider/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace
}

if req.Stream {
interceptor = chatcompletions.NewStreamingInterceptor(id, &req, cfg, r.Header, p.AuthHeader(), tracer)
interceptor = chatcompletions.NewStreamingInterceptor(id, &req, p.Name(), cfg, r.Header, p.AuthHeader(), tracer)
} else {
interceptor = chatcompletions.NewBlockingInterceptor(id, &req, cfg, r.Header, p.AuthHeader(), tracer)
interceptor = chatcompletions.NewBlockingInterceptor(id, &req, p.Name(), cfg, r.Header, p.AuthHeader(), tracer)
}

case routeResponses:
Expand All @@ -141,9 +141,9 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace
return nil, fmt.Errorf("unmarshal request body: %w", err)
}
if reqPayload.Stream() {
interceptor = responses.NewStreamingInterceptor(id, reqPayload, cfg, r.Header, p.AuthHeader(), tracer)
interceptor = responses.NewStreamingInterceptor(id, reqPayload, p.Name(), cfg, r.Header, p.AuthHeader(), tracer)
} else {
interceptor = responses.NewBlockingInterceptor(id, reqPayload, cfg, r.Header, p.AuthHeader(), tracer)
interceptor = responses.NewBlockingInterceptor(id, reqPayload, p.Name(), cfg, r.Header, p.AuthHeader(), tracer)
}

default:
Expand Down
Loading