From 32875245ad690595904d0a723c6828f89007270e Mon Sep 17 00:00:00 2001 From: Susana Cardoso Ferreira Date: Sat, 28 Mar 2026 16:30:42 +0000 Subject: [PATCH] fix: pass provider name to interceptors instead of hardcoding it --- bridge.go | 22 +++++----- circuitbreaker/circuitbreaker.go | 22 +++++----- circuitbreaker/circuitbreaker_test.go | 47 +++++++++++---------- intercept/chatcompletions/base.go | 15 ++++--- intercept/chatcompletions/blocking.go | 2 + intercept/chatcompletions/streaming.go | 2 + intercept/chatcompletions/streaming_test.go | 2 +- intercept/interceptor.go | 2 + intercept/messages/base.go | 13 ++++-- intercept/messages/blocking.go | 2 + intercept/messages/streaming.go | 2 + intercept/responses/base.go | 11 +++-- intercept/responses/blocking.go | 2 + intercept/responses/streaming.go | 2 + provider/anthropic.go | 6 ++- provider/copilot.go | 8 ++-- provider/openai.go | 8 ++-- 17 files changed, 102 insertions(+), 66 deletions(-) diff --git a/bridge.go b/bridge.go index 511a5edc..aeb73af7 100644 --- a/bridge.go +++ b/bridge.go @@ -73,7 +73,7 @@ func NewRequestBridge(ctx context.Context, providers []provider.Provider, rec re // Create per-provider circuit breaker if configured cfg := prov.CircuitBreakerConfig() providerName := prov.Name() - onChange := func(endpoint, model string, from, to gobreaker.State) { + onChange := func(providerName, endpoint, model string, from, to gobreaker.State) { logger.Info(context.Background(), "circuit breaker state change", slog.F("provider", providerName), slog.F("endpoint", endpoint), @@ -165,10 +165,12 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC return } + providerName := interceptor.ProviderName() + if m != nil { start := time.Now() defer func() { - m.InterceptionDuration.WithLabelValues(p.Name(), interceptor.Model()).Observe(time.Since(start).Seconds()) + m.InterceptionDuration.WithLabelValues(providerName, interceptor.Model()).Observe(time.Since(start).Seconds()) }() } @@ -187,7 +189,7 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC // Record usage in the background to not block request flow. asyncRecorder := recorder.NewAsyncRecorder(logger, rec, recordingTimeout) asyncRecorder.WithMetrics(m) - asyncRecorder.WithProvider(p.Name()) + asyncRecorder.WithProvider(providerName) asyncRecorder.WithModel(interceptor.Model()) asyncRecorder.WithInitiatorID(actor.ID) asyncRecorder.WithClient(string(client)) @@ -198,7 +200,7 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC InitiatorID: actor.ID, Metadata: actor.Metadata, Model: interceptor.Model(), - Provider: p.Name(), + Provider: providerName, UserAgent: r.UserAgent(), Client: string(client), ClientSessionID: sessionID, @@ -213,7 +215,7 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC route := strings.TrimPrefix(r.URL.Path, fmt.Sprintf("/%s", p.Name())) log := logger.With( slog.F("route", route), - slog.F("provider", p.Name()), + slog.F("provider", providerName), slog.F("interception_id", interceptor.ID()), slog.F("user_agent", r.UserAgent()), slog.F("streaming", interceptor.Streaming()), @@ -221,24 +223,24 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC log.Debug(ctx, "interception started") if m != nil { - m.InterceptionsInflight.WithLabelValues(p.Name(), interceptor.Model(), route).Add(1) + m.InterceptionsInflight.WithLabelValues(providerName, interceptor.Model(), route).Add(1) defer func() { - m.InterceptionsInflight.WithLabelValues(p.Name(), interceptor.Model(), route).Sub(1) + m.InterceptionsInflight.WithLabelValues(providerName, interceptor.Model(), route).Sub(1) }() } // Process request with circuit breaker protection if configured - if err := cbs.Execute(route, interceptor.Model(), w, func(rw http.ResponseWriter) error { + if err := cbs.Execute(providerName, route, interceptor.Model(), w, func(rw http.ResponseWriter) error { return interceptor.ProcessRequest(rw, r) }); err != nil { if m != nil { - m.InterceptionCount.WithLabelValues(p.Name(), interceptor.Model(), metrics.InterceptionCountStatusFailed, route, r.Method, actor.ID, string(client)).Add(1) + m.InterceptionCount.WithLabelValues(providerName, interceptor.Model(), metrics.InterceptionCountStatusFailed, route, r.Method, actor.ID, string(client)).Add(1) } span.SetStatus(codes.Error, fmt.Sprintf("interception failed: %v", err)) log.Warn(ctx, "interception failed", slog.Error(err)) } else { if m != nil { - m.InterceptionCount.WithLabelValues(p.Name(), interceptor.Model(), metrics.InterceptionCountStatusCompleted, route, r.Method, actor.ID, string(client)).Add(1) + m.InterceptionCount.WithLabelValues(providerName, interceptor.Model(), metrics.InterceptionCountStatusCompleted, route, r.Method, actor.ID, string(client)).Add(1) } log.Debug(ctx, "interception ended") } diff --git a/circuitbreaker/circuitbreaker.go b/circuitbreaker/circuitbreaker.go index 4be1d2b8..a3f06a7b 100644 --- a/circuitbreaker/circuitbreaker.go +++ b/circuitbreaker/circuitbreaker.go @@ -33,8 +33,8 @@ func DefaultIsFailure(statusCode int) bool { type ProviderCircuitBreakers struct { provider string config config.CircuitBreaker - breakers sync.Map // "endpoint:model" -> *gobreaker.CircuitBreaker[struct{}] - onChange func(endpoint, model string, from, to gobreaker.State) + breakers sync.Map // "providerName:endpoint:model" -> *gobreaker.CircuitBreaker[struct{}] + onChange func(providerName, endpoint, model string, from, to gobreaker.State) metrics *metrics.Metrics } @@ -42,7 +42,7 @@ type ProviderCircuitBreakers struct { // Returns nil if cfg is nil (no circuit breaker protection). // onChange is called when circuit state changes. // metrics is used to record circuit breaker reject counts (can be nil). -func NewProviderCircuitBreakers(provider string, cfg *config.CircuitBreaker, onChange func(endpoint, model string, from, to gobreaker.State), m *metrics.Metrics) *ProviderCircuitBreakers { +func NewProviderCircuitBreakers(provider string, cfg *config.CircuitBreaker, onChange func(providerName, endpoint, model string, from, to gobreaker.State), m *metrics.Metrics) *ProviderCircuitBreakers { if cfg == nil { return nil } @@ -71,15 +71,15 @@ func (p *ProviderCircuitBreakers) openErrorResponse() []byte { return []byte(`{"error":"circuit breaker is open"}`) } -// Get returns the circuit breaker for an endpoint/model tuple, creating it if needed. -func (p *ProviderCircuitBreakers) Get(endpoint, model string) *gobreaker.CircuitBreaker[struct{}] { - key := endpoint + ":" + model +// Get returns the circuit breaker for a providerName/endpoint/model tuple, creating it if needed. +func (p *ProviderCircuitBreakers) Get(providerName, endpoint, model string) *gobreaker.CircuitBreaker[struct{}] { + key := providerName + ":" + endpoint + ":" + model if v, ok := p.breakers.Load(key); ok { return v.(*gobreaker.CircuitBreaker[struct{}]) } settings := gobreaker.Settings{ - Name: p.provider + ":" + key, + Name: key, MaxRequests: p.config.MaxRequests, Interval: p.config.Interval, Timeout: p.config.Timeout, @@ -88,7 +88,7 @@ func (p *ProviderCircuitBreakers) Get(endpoint, model string) *gobreaker.Circuit }, OnStateChange: func(_ string, from, to gobreaker.State) { if p.onChange != nil { - p.onChange(endpoint, model, from, to) + p.onChange(providerName, endpoint, model, from, to) } }, } @@ -139,12 +139,12 @@ func (w *statusCapturingWriter) Unwrap() http.ResponseWriter { // Otherwise, it returns the handler's error (or nil on success). // The handler receives a wrapped ResponseWriter that captures the status code. // If the receiver is nil (no circuit breaker configured), the handler is called directly. -func (p *ProviderCircuitBreakers) Execute(endpoint, model string, w http.ResponseWriter, handler func(http.ResponseWriter) error) error { +func (p *ProviderCircuitBreakers) Execute(providerName, endpoint, model string, w http.ResponseWriter, handler func(http.ResponseWriter) error) error { if p == nil { return handler(w) } - cb := p.Get(endpoint, model) + cb := p.Get(providerName, endpoint, model) // Wrap response writer to capture status code sw := &statusCapturingWriter{ResponseWriter: w, statusCode: http.StatusOK} @@ -160,7 +160,7 @@ func (p *ProviderCircuitBreakers) Execute(endpoint, model string, w http.Respons if errors.Is(err, gobreaker.ErrOpenState) || errors.Is(err, gobreaker.ErrTooManyRequests) { if p.metrics != nil { - p.metrics.CircuitBreakerRejects.WithLabelValues(p.provider, endpoint, model).Inc() + p.metrics.CircuitBreakerRejects.WithLabelValues(providerName, endpoint, model).Inc() } w.Header().Set("Content-Type", "application/json") w.Header().Set("Retry-After", fmt.Sprintf("%d", int64(p.config.Timeout.Seconds()))) diff --git a/circuitbreaker/circuitbreaker_test.go b/circuitbreaker/circuitbreaker_test.go index 18913718..96eafe8d 100644 --- a/circuitbreaker/circuitbreaker_test.go +++ b/circuitbreaker/circuitbreaker_test.go @@ -24,7 +24,7 @@ func TestExecute_PerModelIsolation(t *testing.T) { Interval: time.Minute, Timeout: time.Minute, MaxRequests: 1, - }, func(endpoint, model string, from, to gobreaker.State) {}, nil) + }, func(providerName, endpoint, model string, from, to gobreaker.State) {}, nil) endpoint := "/v1/messages" sonnetModel := "claude-sonnet-4-20250514" @@ -32,7 +32,7 @@ func TestExecute_PerModelIsolation(t *testing.T) { // Trip circuit on sonnet model (returns 429) w := httptest.NewRecorder() - err := cbs.Execute(endpoint, sonnetModel, w, func(rw http.ResponseWriter) error { + err := cbs.Execute(config.ProviderAnthropic, endpoint, sonnetModel, w, func(rw http.ResponseWriter) error { sonnetCalls.Add(1) rw.WriteHeader(http.StatusTooManyRequests) return nil @@ -42,7 +42,7 @@ func TestExecute_PerModelIsolation(t *testing.T) { // Second sonnet request should be blocked by circuit breaker w = httptest.NewRecorder() - err = cbs.Execute(endpoint, sonnetModel, w, func(rw http.ResponseWriter) error { + err = cbs.Execute(config.ProviderAnthropic, endpoint, sonnetModel, w, func(rw http.ResponseWriter) error { sonnetCalls.Add(1) rw.WriteHeader(http.StatusOK) return nil @@ -53,7 +53,7 @@ func TestExecute_PerModelIsolation(t *testing.T) { // Haiku model on same endpoint should still work (independent circuit) w = httptest.NewRecorder() - err = cbs.Execute(endpoint, haikuModel, w, func(rw http.ResponseWriter) error { + err = cbs.Execute(config.ProviderAnthropic, endpoint, haikuModel, w, func(rw http.ResponseWriter) error { haikuCalls.Add(1) rw.WriteHeader(http.StatusOK) return nil @@ -73,13 +73,13 @@ func TestExecute_PerEndpointIsolation(t *testing.T) { Interval: time.Minute, Timeout: time.Minute, MaxRequests: 1, - }, func(endpoint, model string, from, to gobreaker.State) {}, nil) + }, func(providerName, endpoint, model string, from, to gobreaker.State) {}, nil) model := "test-model" // Trip circuit on /v1/messages endpoint (returns 429) w := httptest.NewRecorder() - err := cbs.Execute("/v1/messages", model, w, func(rw http.ResponseWriter) error { + err := cbs.Execute(config.ProviderAnthropic, "/v1/messages", model, w, func(rw http.ResponseWriter) error { messagesCalls.Add(1) rw.WriteHeader(http.StatusTooManyRequests) return nil @@ -89,7 +89,7 @@ func TestExecute_PerEndpointIsolation(t *testing.T) { // Second /v1/messages request should be blocked w = httptest.NewRecorder() - err = cbs.Execute("/v1/messages", model, w, func(rw http.ResponseWriter) error { + err = cbs.Execute(config.ProviderAnthropic, "/v1/messages", model, w, func(rw http.ResponseWriter) error { messagesCalls.Add(1) rw.WriteHeader(http.StatusOK) return nil @@ -100,7 +100,7 @@ func TestExecute_PerEndpointIsolation(t *testing.T) { // /v1/chat/completions on same model should still work (different endpoint) w = httptest.NewRecorder() - err = cbs.Execute("/v1/chat/completions", model, w, func(rw http.ResponseWriter) error { + err = cbs.Execute(config.ProviderOpenAI, "/v1/chat/completions", model, w, func(rw http.ResponseWriter) error { completionsCalls.Add(1) rw.WriteHeader(http.StatusOK) return nil @@ -123,11 +123,11 @@ func TestExecute_CustomIsFailure(t *testing.T) { IsFailure: func(statusCode int) bool { return statusCode == http.StatusBadGateway }, - }, func(endpoint, model string, from, to gobreaker.State) {}, nil) + }, func(providerName, endpoint, model string, from, to gobreaker.State) {}, nil) // First request returns 502, trips circuit w := httptest.NewRecorder() - err := cbs.Execute("/v1/messages", "test-model", w, func(rw http.ResponseWriter) error { + err := cbs.Execute(config.ProviderAnthropic, "/v1/messages", "test-model", w, func(rw http.ResponseWriter) error { calls.Add(1) rw.WriteHeader(http.StatusBadGateway) return nil @@ -137,7 +137,7 @@ func TestExecute_CustomIsFailure(t *testing.T) { // Second request should be blocked w = httptest.NewRecorder() - err = cbs.Execute("/v1/messages", "test-model", w, func(rw http.ResponseWriter) error { + err = cbs.Execute(config.ProviderAnthropic, "/v1/messages", "test-model", w, func(rw http.ResponseWriter) error { calls.Add(1) rw.WriteHeader(http.StatusOK) return nil @@ -151,10 +151,11 @@ func TestExecute_OnStateChange(t *testing.T) { t.Parallel() var stateChanges []struct { - endpoint string - model string - from gobreaker.State - to gobreaker.State + providerName string + endpoint string + model string + from gobreaker.State + to gobreaker.State } cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{ @@ -162,13 +163,14 @@ func TestExecute_OnStateChange(t *testing.T) { Interval: time.Minute, Timeout: time.Minute, MaxRequests: 1, - }, func(endpoint, model string, from, to gobreaker.State) { + }, func(providerName, endpoint, model string, from, to gobreaker.State) { stateChanges = append(stateChanges, struct { - endpoint string - model string - from gobreaker.State - to gobreaker.State - }{endpoint, model, from, to}) + providerName string + endpoint string + model string + from gobreaker.State + to gobreaker.State + }{providerName, endpoint, model, from, to}) }, nil) endpoint := "/v1/messages" @@ -176,13 +178,14 @@ func TestExecute_OnStateChange(t *testing.T) { // Trip circuit w := httptest.NewRecorder() - cbs.Execute(endpoint, model, w, func(rw http.ResponseWriter) error { + cbs.Execute(config.ProviderAnthropic, endpoint, model, w, func(rw http.ResponseWriter) error { rw.WriteHeader(http.StatusTooManyRequests) return nil }) // Verify state change callback was called with correct parameters assert.Len(t, stateChanges, 1) + assert.Equal(t, config.ProviderAnthropic, stateChanges[0].providerName) assert.Equal(t, endpoint, stateChanges[0].endpoint) assert.Equal(t, model, stateChanges[0].model) assert.Equal(t, gobreaker.StateClosed, stateChanges[0].from) diff --git a/intercept/chatcompletions/base.go b/intercept/chatcompletions/base.go index b71cf90f..03d9131f 100644 --- a/intercept/chatcompletions/base.go +++ b/intercept/chatcompletions/base.go @@ -26,9 +26,10 @@ import ( ) type interceptionBase struct { - id uuid.UUID - req *ChatCompletionNewParamsWrapper - cfg config.OpenAI + id uuid.UUID + providerName string + req *ChatCompletionNewParamsWrapper + cfg config.OpenAI // clientHeaders are the original HTTP headers from the client request. clientHeaders http.Header @@ -62,7 +63,7 @@ func (i *interceptionBase) newCompletionsService() openai.ChatCompletionService } // Add API dump middleware if configured - if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, config.ProviderOpenAI, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { + if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, i.ProviderName(), i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { opts = append(opts, option.WithMiddleware(mw)) } @@ -73,6 +74,10 @@ func (i *interceptionBase) ID() uuid.UUID { return i.id } +func (i *interceptionBase) ProviderName() string { + return i.providerName +} + func (i *interceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) { i.logger = logger i.recorder = recorder @@ -97,7 +102,7 @@ func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) attribute.String(tracing.RequestPath, r.URL.Path), attribute.String(tracing.InterceptionID, s.id.String()), attribute.String(tracing.InitiatorID, aibcontext.ActorIDFromContext(r.Context())), - attribute.String(tracing.Provider, config.ProviderOpenAI), + attribute.String(tracing.Provider, s.ProviderName()), attribute.String(tracing.Model, s.Model()), attribute.Bool(tracing.Streaming, streaming), } diff --git a/intercept/chatcompletions/blocking.go b/intercept/chatcompletions/blocking.go index 2816ed7a..9e398d05 100644 --- a/intercept/chatcompletions/blocking.go +++ b/intercept/chatcompletions/blocking.go @@ -31,6 +31,7 @@ type BlockingInterception struct { func NewBlockingInterceptor( id uuid.UUID, req *ChatCompletionNewParamsWrapper, + providerName string, cfg config.OpenAI, clientHeaders http.Header, authHeaderName string, @@ -38,6 +39,7 @@ func NewBlockingInterceptor( ) *BlockingInterception { return &BlockingInterception{interceptionBase: interceptionBase{ id: id, + providerName: providerName, req: req, cfg: cfg, clientHeaders: clientHeaders, diff --git a/intercept/chatcompletions/streaming.go b/intercept/chatcompletions/streaming.go index b550c8e6..1d705e36 100644 --- a/intercept/chatcompletions/streaming.go +++ b/intercept/chatcompletions/streaming.go @@ -36,6 +36,7 @@ type StreamingInterception struct { func NewStreamingInterceptor( id uuid.UUID, req *ChatCompletionNewParamsWrapper, + providerName string, cfg config.OpenAI, clientHeaders http.Header, authHeaderName string, @@ -43,6 +44,7 @@ func NewStreamingInterceptor( ) *StreamingInterception { return &StreamingInterception{interceptionBase: interceptionBase{ id: id, + providerName: providerName, req: req, cfg: cfg, clientHeaders: clientHeaders, diff --git a/intercept/chatcompletions/streaming_test.go b/intercept/chatcompletions/streaming_test.go index 233831e6..54c47336 100644 --- a/intercept/chatcompletions/streaming_test.go +++ b/intercept/chatcompletions/streaming_test.go @@ -86,7 +86,7 @@ func TestStreamingInterception_RelaysUpstreamErrorToClient(t *testing.T) { httpReq := httptest.NewRequest(http.MethodPost, "/chat/completions", nil) tracer := otel.Tracer("test") - interceptor := NewStreamingInterceptor(uuid.New(), req, cfg, httpReq.Header, "Authorization", tracer) + interceptor := NewStreamingInterceptor(uuid.New(), req, config.ProviderOpenAI, cfg, httpReq.Header, "Authorization", tracer) logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) interceptor.Setup(logger, &testutil.MockRecorder{}, nil) diff --git a/intercept/interceptor.go b/intercept/interceptor.go index cbd29d62..e8db619a 100644 --- a/intercept/interceptor.go +++ b/intercept/interceptor.go @@ -25,6 +25,8 @@ type Interceptor interface { Streaming() bool // TraceAttributes returns tracing attributes for this [Interceptor] TraceAttributes(*http.Request) []attribute.KeyValue + // ProviderName returns the provider name for this interception. + ProviderName() string // CorrelatingToolCallID returns the ID of a tool call result submitted // in the request, if present. This is used to correlate the current // interception back to the previous interception that issued those tool diff --git a/intercept/messages/base.go b/intercept/messages/base.go index af7d5915..214fb968 100644 --- a/intercept/messages/base.go +++ b/intercept/messages/base.go @@ -63,8 +63,9 @@ var bedrockSupportedBetaFlags = map[string]bool{ } type interceptionBase struct { - id uuid.UUID - reqPayload MessagesRequestPayload + id uuid.UUID + providerName string + reqPayload MessagesRequestPayload cfg aibconfig.Anthropic bedrockCfg *aibconfig.AWSBedrock @@ -84,6 +85,10 @@ func (i *interceptionBase) ID() uuid.UUID { return i.id } +func (i *interceptionBase) ProviderName() string { + return i.providerName +} + func (i *interceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) { i.logger = logger i.recorder = recorder @@ -115,7 +120,7 @@ func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) attribute.String(tracing.RequestPath, r.URL.Path), attribute.String(tracing.InterceptionID, s.id.String()), attribute.String(tracing.InitiatorID, aibcontext.ActorIDFromContext(r.Context())), - attribute.String(tracing.Provider, aibconfig.ProviderAnthropic), + attribute.String(tracing.Provider, s.ProviderName()), attribute.String(tracing.Model, s.Model()), attribute.Bool(tracing.Streaming, streaming), attribute.Bool(tracing.IsBedrock, s.bedrockCfg != nil), @@ -232,7 +237,7 @@ func (i *interceptionBase) newMessagesService(ctx context.Context, opts ...optio } // Add API dump middleware if configured - if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, aibconfig.ProviderAnthropic, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { + if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, i.ProviderName(), i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { opts = append(opts, option.WithMiddleware(mw)) } diff --git a/intercept/messages/blocking.go b/intercept/messages/blocking.go index 7ed267cd..3072b125 100644 --- a/intercept/messages/blocking.go +++ b/intercept/messages/blocking.go @@ -31,6 +31,7 @@ type BlockingInterception struct { func NewBlockingInterceptor( id uuid.UUID, reqPayload MessagesRequestPayload, + providerName string, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, clientHeaders http.Header, @@ -39,6 +40,7 @@ func NewBlockingInterceptor( ) *BlockingInterception { return &BlockingInterception{interceptionBase: interceptionBase{ id: id, + providerName: providerName, reqPayload: reqPayload, cfg: cfg, bedrockCfg: bedrockCfg, diff --git a/intercept/messages/streaming.go b/intercept/messages/streaming.go index d317c55b..395c6021 100644 --- a/intercept/messages/streaming.go +++ b/intercept/messages/streaming.go @@ -37,6 +37,7 @@ type StreamingInterception struct { func NewStreamingInterceptor( id uuid.UUID, reqPayload MessagesRequestPayload, + providerName string, cfg config.Anthropic, bedrockCfg *config.AWSBedrock, clientHeaders http.Header, @@ -45,6 +46,7 @@ func NewStreamingInterceptor( ) *StreamingInterception { return &StreamingInterception{interceptionBase: interceptionBase{ id: id, + providerName: providerName, reqPayload: reqPayload, cfg: cfg, bedrockCfg: bedrockCfg, diff --git a/intercept/responses/base.go b/intercept/responses/base.go index e127d1fb..1f50fceb 100644 --- a/intercept/responses/base.go +++ b/intercept/responses/base.go @@ -36,7 +36,8 @@ const ( ) type responsesInterceptionBase struct { - id uuid.UUID + id uuid.UUID + providerName string // clientHeaders are the original HTTP headers from the client request. clientHeaders http.Header authHeaderName string @@ -71,7 +72,7 @@ func (i *responsesInterceptionBase) newResponsesService() responses.ResponseServ } // Add API dump middleware if configured - if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, config.ProviderOpenAI, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { + if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, i.ProviderName(), i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { opts = append(opts, option.WithMiddleware(mw)) } @@ -82,6 +83,10 @@ func (i *responsesInterceptionBase) ID() uuid.UUID { return i.id } +func (i *responsesInterceptionBase) ProviderName() string { + return i.providerName +} + func (i *responsesInterceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) { i.logger = logger.With(slog.F("model", i.Model())) i.recorder = recorder @@ -101,7 +106,7 @@ func (i *responsesInterceptionBase) baseTraceAttributes(r *http.Request, streami attribute.String(tracing.RequestPath, r.URL.Path), attribute.String(tracing.InterceptionID, i.id.String()), attribute.String(tracing.InitiatorID, aibcontext.ActorIDFromContext(r.Context())), - attribute.String(tracing.Provider, config.ProviderOpenAI), + attribute.String(tracing.Provider, i.ProviderName()), attribute.String(tracing.Model, i.Model()), attribute.Bool(tracing.Streaming, streaming), } diff --git a/intercept/responses/blocking.go b/intercept/responses/blocking.go index 62944310..d64adf9f 100644 --- a/intercept/responses/blocking.go +++ b/intercept/responses/blocking.go @@ -28,6 +28,7 @@ type BlockingResponsesInterceptor struct { func NewBlockingInterceptor( id uuid.UUID, reqPayload ResponsesRequestPayload, + providerName string, cfg config.OpenAI, clientHeaders http.Header, authHeaderName string, @@ -36,6 +37,7 @@ func NewBlockingInterceptor( return &BlockingResponsesInterceptor{ responsesInterceptionBase: responsesInterceptionBase{ id: id, + providerName: providerName, reqPayload: reqPayload, cfg: cfg, clientHeaders: clientHeaders, diff --git a/intercept/responses/streaming.go b/intercept/responses/streaming.go index dbd50673..359f82e8 100644 --- a/intercept/responses/streaming.go +++ b/intercept/responses/streaming.go @@ -35,6 +35,7 @@ type StreamingResponsesInterceptor struct { func NewStreamingInterceptor( id uuid.UUID, reqPayload ResponsesRequestPayload, + providerName string, cfg config.OpenAI, clientHeaders http.Header, authHeaderName string, @@ -43,6 +44,7 @@ func NewStreamingInterceptor( return &StreamingResponsesInterceptor{ responsesInterceptionBase: responsesInterceptionBase{ id: id, + providerName: providerName, reqPayload: reqPayload, cfg: cfg, clientHeaders: clientHeaders, diff --git a/provider/anthropic.go b/provider/anthropic.go index 4825db2d..e980a163 100644 --- a/provider/anthropic.go +++ b/provider/anthropic.go @@ -133,11 +133,13 @@ func (p *Anthropic) CreateInterceptor(w http.ResponseWriter, r *http.Request, tr authHeaderName = "Authorization" } + // TODO(ssncferreira): when Bedrock is added as a separate provider, pass the + // resolved provider name instead of p.Name() here. var interceptor intercept.Interceptor if reqPayload.Stream() { - interceptor = messages.NewStreamingInterceptor(id, reqPayload, cfg, p.bedrockCfg, r.Header, authHeaderName, tracer) + interceptor = messages.NewStreamingInterceptor(id, reqPayload, p.Name(), cfg, p.bedrockCfg, r.Header, authHeaderName, tracer) } else { - interceptor = messages.NewBlockingInterceptor(id, reqPayload, cfg, p.bedrockCfg, r.Header, authHeaderName, tracer) + interceptor = messages.NewBlockingInterceptor(id, reqPayload, p.Name(), cfg, p.bedrockCfg, r.Header, authHeaderName, tracer) } span.SetAttributes(interceptor.TraceAttributes(r)...) return interceptor, nil diff --git a/provider/copilot.go b/provider/copilot.go index 99b0fca3..abf19c82 100644 --- a/provider/copilot.go +++ b/provider/copilot.go @@ -149,9 +149,9 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac } if req.Stream { - interceptor = chatcompletions.NewStreamingInterceptor(id, &req, cfg, r.Header, p.AuthHeader(), tracer) + interceptor = chatcompletions.NewStreamingInterceptor(id, &req, p.Name(), cfg, r.Header, p.AuthHeader(), tracer) } else { - interceptor = chatcompletions.NewBlockingInterceptor(id, &req, cfg, r.Header, p.AuthHeader(), tracer) + interceptor = chatcompletions.NewBlockingInterceptor(id, &req, p.Name(), cfg, r.Header, p.AuthHeader(), tracer) } case routeCopilotResponses: @@ -165,9 +165,9 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac } if reqPayload.Stream() { - interceptor = responses.NewStreamingInterceptor(id, reqPayload, cfg, r.Header, p.AuthHeader(), tracer) + interceptor = responses.NewStreamingInterceptor(id, reqPayload, p.Name(), cfg, r.Header, p.AuthHeader(), tracer) } else { - interceptor = responses.NewBlockingInterceptor(id, reqPayload, cfg, r.Header, p.AuthHeader(), tracer) + interceptor = responses.NewBlockingInterceptor(id, reqPayload, p.Name(), cfg, r.Header, p.AuthHeader(), tracer) } default: diff --git a/provider/openai.go b/provider/openai.go index cfedf009..a1624e10 100644 --- a/provider/openai.go +++ b/provider/openai.go @@ -119,9 +119,9 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace } if req.Stream { - interceptor = chatcompletions.NewStreamingInterceptor(id, &req, cfg, r.Header, p.AuthHeader(), tracer) + interceptor = chatcompletions.NewStreamingInterceptor(id, &req, p.Name(), cfg, r.Header, p.AuthHeader(), tracer) } else { - interceptor = chatcompletions.NewBlockingInterceptor(id, &req, cfg, r.Header, p.AuthHeader(), tracer) + interceptor = chatcompletions.NewBlockingInterceptor(id, &req, p.Name(), cfg, r.Header, p.AuthHeader(), tracer) } case routeResponses: @@ -134,9 +134,9 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace return nil, fmt.Errorf("unmarshal request body: %w", err) } if reqPayload.Stream() { - interceptor = responses.NewStreamingInterceptor(id, reqPayload, cfg, r.Header, p.AuthHeader(), tracer) + interceptor = responses.NewStreamingInterceptor(id, reqPayload, p.Name(), cfg, r.Header, p.AuthHeader(), tracer) } else { - interceptor = responses.NewBlockingInterceptor(id, reqPayload, cfg, r.Header, p.AuthHeader(), tracer) + interceptor = responses.NewBlockingInterceptor(id, reqPayload, p.Name(), cfg, r.Header, p.AuthHeader(), tracer) } default: