Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func NewRequestBridge(ctx context.Context, providers []provider.Provider, rec re
// Create per-provider circuit breaker if configured
cfg := prov.CircuitBreakerConfig()
providerName := prov.Name()
onChange := func(endpoint, model string, from, to gobreaker.State) {
onChange := func(providerName, endpoint, model string, from, to gobreaker.State) {
logger.Info(context.Background(), "circuit breaker state change",
slog.F("provider", providerName),
slog.F("endpoint", endpoint),
Expand Down Expand Up @@ -165,10 +165,12 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC
return
}

providerName := interceptor.ProviderName()

if m != nil {
start := time.Now()
defer func() {
m.InterceptionDuration.WithLabelValues(p.Name(), interceptor.Model()).Observe(time.Since(start).Seconds())
m.InterceptionDuration.WithLabelValues(providerName, interceptor.Model()).Observe(time.Since(start).Seconds())
}()
}

Expand All @@ -187,7 +189,7 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC
// Record usage in the background to not block request flow.
asyncRecorder := recorder.NewAsyncRecorder(logger, rec, recordingTimeout)
asyncRecorder.WithMetrics(m)
asyncRecorder.WithProvider(p.Name())
asyncRecorder.WithProvider(providerName)
asyncRecorder.WithModel(interceptor.Model())
asyncRecorder.WithInitiatorID(actor.ID)
asyncRecorder.WithClient(string(client))
Expand All @@ -198,7 +200,7 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC
InitiatorID: actor.ID,
Metadata: actor.Metadata,
Model: interceptor.Model(),
Provider: p.Name(),
Provider: providerName,
UserAgent: r.UserAgent(),
Client: string(client),
ClientSessionID: sessionID,
Expand All @@ -213,32 +215,32 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC
route := strings.TrimPrefix(r.URL.Path, fmt.Sprintf("/%s", p.Name()))
log := logger.With(
slog.F("route", route),
slog.F("provider", p.Name()),
slog.F("provider", providerName),
slog.F("interception_id", interceptor.ID()),
slog.F("user_agent", r.UserAgent()),
slog.F("streaming", interceptor.Streaming()),
)

log.Debug(ctx, "interception started")
if m != nil {
m.InterceptionsInflight.WithLabelValues(p.Name(), interceptor.Model(), route).Add(1)
m.InterceptionsInflight.WithLabelValues(providerName, interceptor.Model(), route).Add(1)
defer func() {
m.InterceptionsInflight.WithLabelValues(p.Name(), interceptor.Model(), route).Sub(1)
m.InterceptionsInflight.WithLabelValues(providerName, interceptor.Model(), route).Sub(1)
}()
}

// Process request with circuit breaker protection if configured
if err := cbs.Execute(route, interceptor.Model(), w, func(rw http.ResponseWriter) error {
if err := cbs.Execute(providerName, route, interceptor.Model(), w, func(rw http.ResponseWriter) error {
return interceptor.ProcessRequest(rw, r)
}); err != nil {
if m != nil {
m.InterceptionCount.WithLabelValues(p.Name(), interceptor.Model(), metrics.InterceptionCountStatusFailed, route, r.Method, actor.ID, string(client)).Add(1)
m.InterceptionCount.WithLabelValues(providerName, interceptor.Model(), metrics.InterceptionCountStatusFailed, route, r.Method, actor.ID, string(client)).Add(1)
}
span.SetStatus(codes.Error, fmt.Sprintf("interception failed: %v", err))
log.Warn(ctx, "interception failed", slog.Error(err))
} else {
if m != nil {
m.InterceptionCount.WithLabelValues(p.Name(), interceptor.Model(), metrics.InterceptionCountStatusCompleted, route, r.Method, actor.ID, string(client)).Add(1)
m.InterceptionCount.WithLabelValues(providerName, interceptor.Model(), metrics.InterceptionCountStatusCompleted, route, r.Method, actor.ID, string(client)).Add(1)
}
log.Debug(ctx, "interception ended")
}
Expand Down
22 changes: 11 additions & 11 deletions circuitbreaker/circuitbreaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,16 @@ func DefaultIsFailure(statusCode int) bool {
type ProviderCircuitBreakers struct {
provider string
config config.CircuitBreaker
breakers sync.Map // "endpoint:model" -> *gobreaker.CircuitBreaker[struct{}]
onChange func(endpoint, model string, from, to gobreaker.State)
breakers sync.Map // "providerName:endpoint:model" -> *gobreaker.CircuitBreaker[struct{}]
onChange func(providerName, endpoint, model string, from, to gobreaker.State)
metrics *metrics.Metrics
}

// NewProviderCircuitBreakers creates circuit breakers for a single provider.
// Returns nil if cfg is nil (no circuit breaker protection).
// onChange is called when circuit state changes.
// metrics is used to record circuit breaker reject counts (can be nil).
func NewProviderCircuitBreakers(provider string, cfg *config.CircuitBreaker, onChange func(endpoint, model string, from, to gobreaker.State), m *metrics.Metrics) *ProviderCircuitBreakers {
func NewProviderCircuitBreakers(provider string, cfg *config.CircuitBreaker, onChange func(providerName, endpoint, model string, from, to gobreaker.State), m *metrics.Metrics) *ProviderCircuitBreakers {
if cfg == nil {
return nil
}
Expand Down Expand Up @@ -71,15 +71,15 @@ func (p *ProviderCircuitBreakers) openErrorResponse() []byte {
return []byte(`{"error":"circuit breaker is open"}`)
}

// Get returns the circuit breaker for an endpoint/model tuple, creating it if needed.
func (p *ProviderCircuitBreakers) Get(endpoint, model string) *gobreaker.CircuitBreaker[struct{}] {
key := endpoint + ":" + model
// Get returns the circuit breaker for a providerName/endpoint/model tuple, creating it if needed.
func (p *ProviderCircuitBreakers) Get(providerName, endpoint, model string) *gobreaker.CircuitBreaker[struct{}] {
key := providerName + ":" + endpoint + ":" + model
if v, ok := p.breakers.Load(key); ok {
return v.(*gobreaker.CircuitBreaker[struct{}])
}

settings := gobreaker.Settings{
Name: p.provider + ":" + key,
Name: key,
MaxRequests: p.config.MaxRequests,
Interval: p.config.Interval,
Timeout: p.config.Timeout,
Expand All @@ -88,7 +88,7 @@ func (p *ProviderCircuitBreakers) Get(endpoint, model string) *gobreaker.Circuit
},
OnStateChange: func(_ string, from, to gobreaker.State) {
if p.onChange != nil {
p.onChange(endpoint, model, from, to)
p.onChange(providerName, endpoint, model, from, to)
}
},
}
Expand Down Expand Up @@ -139,12 +139,12 @@ func (w *statusCapturingWriter) Unwrap() http.ResponseWriter {
// Otherwise, it returns the handler's error (or nil on success).
// The handler receives a wrapped ResponseWriter that captures the status code.
// If the receiver is nil (no circuit breaker configured), the handler is called directly.
func (p *ProviderCircuitBreakers) Execute(endpoint, model string, w http.ResponseWriter, handler func(http.ResponseWriter) error) error {
func (p *ProviderCircuitBreakers) Execute(providerName, endpoint, model string, w http.ResponseWriter, handler func(http.ResponseWriter) error) error {
if p == nil {
return handler(w)
}

cb := p.Get(endpoint, model)
cb := p.Get(providerName, endpoint, model)

// Wrap response writer to capture status code
sw := &statusCapturingWriter{ResponseWriter: w, statusCode: http.StatusOK}
Expand All @@ -160,7 +160,7 @@ func (p *ProviderCircuitBreakers) Execute(endpoint, model string, w http.Respons

if errors.Is(err, gobreaker.ErrOpenState) || errors.Is(err, gobreaker.ErrTooManyRequests) {
if p.metrics != nil {
p.metrics.CircuitBreakerRejects.WithLabelValues(p.provider, endpoint, model).Inc()
p.metrics.CircuitBreakerRejects.WithLabelValues(providerName, endpoint, model).Inc()
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Retry-After", fmt.Sprintf("%d", int64(p.config.Timeout.Seconds())))
Expand Down
47 changes: 25 additions & 22 deletions circuitbreaker/circuitbreaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ func TestExecute_PerModelIsolation(t *testing.T) {
Interval: time.Minute,
Timeout: time.Minute,
MaxRequests: 1,
}, func(endpoint, model string, from, to gobreaker.State) {}, nil)
}, func(providerName, endpoint, model string, from, to gobreaker.State) {}, nil)

endpoint := "/v1/messages"
sonnetModel := "claude-sonnet-4-20250514"
haikuModel := "claude-3-5-haiku-20241022"

// Trip circuit on sonnet model (returns 429)
w := httptest.NewRecorder()
err := cbs.Execute(endpoint, sonnetModel, w, func(rw http.ResponseWriter) error {
err := cbs.Execute(config.ProviderAnthropic, endpoint, sonnetModel, w, func(rw http.ResponseWriter) error {
sonnetCalls.Add(1)
rw.WriteHeader(http.StatusTooManyRequests)
return nil
Expand All @@ -42,7 +42,7 @@ func TestExecute_PerModelIsolation(t *testing.T) {

// Second sonnet request should be blocked by circuit breaker
w = httptest.NewRecorder()
err = cbs.Execute(endpoint, sonnetModel, w, func(rw http.ResponseWriter) error {
err = cbs.Execute(config.ProviderAnthropic, endpoint, sonnetModel, w, func(rw http.ResponseWriter) error {
sonnetCalls.Add(1)
rw.WriteHeader(http.StatusOK)
return nil
Expand All @@ -53,7 +53,7 @@ func TestExecute_PerModelIsolation(t *testing.T) {

// Haiku model on same endpoint should still work (independent circuit)
w = httptest.NewRecorder()
err = cbs.Execute(endpoint, haikuModel, w, func(rw http.ResponseWriter) error {
err = cbs.Execute(config.ProviderAnthropic, endpoint, haikuModel, w, func(rw http.ResponseWriter) error {
haikuCalls.Add(1)
rw.WriteHeader(http.StatusOK)
return nil
Expand All @@ -73,13 +73,13 @@ func TestExecute_PerEndpointIsolation(t *testing.T) {
Interval: time.Minute,
Timeout: time.Minute,
MaxRequests: 1,
}, func(endpoint, model string, from, to gobreaker.State) {}, nil)
}, func(providerName, endpoint, model string, from, to gobreaker.State) {}, nil)

model := "test-model"

// Trip circuit on /v1/messages endpoint (returns 429)
w := httptest.NewRecorder()
err := cbs.Execute("/v1/messages", model, w, func(rw http.ResponseWriter) error {
err := cbs.Execute(config.ProviderAnthropic, "/v1/messages", model, w, func(rw http.ResponseWriter) error {
messagesCalls.Add(1)
rw.WriteHeader(http.StatusTooManyRequests)
return nil
Expand All @@ -89,7 +89,7 @@ func TestExecute_PerEndpointIsolation(t *testing.T) {

// Second /v1/messages request should be blocked
w = httptest.NewRecorder()
err = cbs.Execute("/v1/messages", model, w, func(rw http.ResponseWriter) error {
err = cbs.Execute(config.ProviderAnthropic, "/v1/messages", model, w, func(rw http.ResponseWriter) error {
messagesCalls.Add(1)
rw.WriteHeader(http.StatusOK)
return nil
Expand All @@ -100,7 +100,7 @@ func TestExecute_PerEndpointIsolation(t *testing.T) {

// /v1/chat/completions on same model should still work (different endpoint)
w = httptest.NewRecorder()
err = cbs.Execute("/v1/chat/completions", model, w, func(rw http.ResponseWriter) error {
err = cbs.Execute(config.ProviderOpenAI, "/v1/chat/completions", model, w, func(rw http.ResponseWriter) error {
completionsCalls.Add(1)
rw.WriteHeader(http.StatusOK)
return nil
Expand All @@ -123,11 +123,11 @@ func TestExecute_CustomIsFailure(t *testing.T) {
IsFailure: func(statusCode int) bool {
return statusCode == http.StatusBadGateway
},
}, func(endpoint, model string, from, to gobreaker.State) {}, nil)
}, func(providerName, endpoint, model string, from, to gobreaker.State) {}, nil)

// First request returns 502, trips circuit
w := httptest.NewRecorder()
err := cbs.Execute("/v1/messages", "test-model", w, func(rw http.ResponseWriter) error {
err := cbs.Execute(config.ProviderAnthropic, "/v1/messages", "test-model", w, func(rw http.ResponseWriter) error {
calls.Add(1)
rw.WriteHeader(http.StatusBadGateway)
return nil
Expand All @@ -137,7 +137,7 @@ func TestExecute_CustomIsFailure(t *testing.T) {

// Second request should be blocked
w = httptest.NewRecorder()
err = cbs.Execute("/v1/messages", "test-model", w, func(rw http.ResponseWriter) error {
err = cbs.Execute(config.ProviderAnthropic, "/v1/messages", "test-model", w, func(rw http.ResponseWriter) error {
calls.Add(1)
rw.WriteHeader(http.StatusOK)
return nil
Expand All @@ -151,38 +151,41 @@ func TestExecute_OnStateChange(t *testing.T) {
t.Parallel()

var stateChanges []struct {
endpoint string
model string
from gobreaker.State
to gobreaker.State
providerName string
endpoint string
model string
from gobreaker.State
to gobreaker.State
}

cbs := NewProviderCircuitBreakers("test", &config.CircuitBreaker{
FailureThreshold: 1,
Interval: time.Minute,
Timeout: time.Minute,
MaxRequests: 1,
}, func(endpoint, model string, from, to gobreaker.State) {
}, func(providerName, endpoint, model string, from, to gobreaker.State) {
stateChanges = append(stateChanges, struct {
endpoint string
model string
from gobreaker.State
to gobreaker.State
}{endpoint, model, from, to})
providerName string
endpoint string
model string
from gobreaker.State
to gobreaker.State
}{providerName, endpoint, model, from, to})
}, nil)

endpoint := "/v1/messages"
model := "claude-sonnet-4-20250514"

// Trip circuit
w := httptest.NewRecorder()
cbs.Execute(endpoint, model, w, func(rw http.ResponseWriter) error {
cbs.Execute(config.ProviderAnthropic, endpoint, model, w, func(rw http.ResponseWriter) error {
rw.WriteHeader(http.StatusTooManyRequests)
return nil
})

// Verify state change callback was called with correct parameters
assert.Len(t, stateChanges, 1)
assert.Equal(t, config.ProviderAnthropic, stateChanges[0].providerName)
assert.Equal(t, endpoint, stateChanges[0].endpoint)
assert.Equal(t, model, stateChanges[0].model)
assert.Equal(t, gobreaker.StateClosed, stateChanges[0].from)
Expand Down
15 changes: 10 additions & 5 deletions intercept/chatcompletions/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@ import (
)

type interceptionBase struct {
id uuid.UUID
req *ChatCompletionNewParamsWrapper
cfg config.OpenAI
id uuid.UUID
providerName string
req *ChatCompletionNewParamsWrapper
cfg config.OpenAI

// clientHeaders are the original HTTP headers from the client request.
clientHeaders http.Header
Expand Down Expand Up @@ -62,7 +63,7 @@ func (i *interceptionBase) newCompletionsService() openai.ChatCompletionService
}

// Add API dump middleware if configured
if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, config.ProviderOpenAI, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil {
if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, i.ProviderName(), i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil {
opts = append(opts, option.WithMiddleware(mw))
}

Expand All @@ -73,6 +74,10 @@ func (i *interceptionBase) ID() uuid.UUID {
return i.id
}

// ProviderName returns the provider name held in this interception's
// providerName field.
func (i *interceptionBase) ProviderName() string {
return i.providerName
}

func (i *interceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) {
i.logger = logger
i.recorder = recorder
Expand All @@ -97,7 +102,7 @@ func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool)
attribute.String(tracing.RequestPath, r.URL.Path),
attribute.String(tracing.InterceptionID, s.id.String()),
attribute.String(tracing.InitiatorID, aibcontext.ActorIDFromContext(r.Context())),
attribute.String(tracing.Provider, config.ProviderOpenAI),
attribute.String(tracing.Provider, s.ProviderName()),
attribute.String(tracing.Model, s.Model()),
attribute.Bool(tracing.Streaming, streaming),
}
Expand Down
2 changes: 2 additions & 0 deletions intercept/chatcompletions/blocking.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,15 @@ type BlockingInterception struct {
func NewBlockingInterceptor(
id uuid.UUID,
req *ChatCompletionNewParamsWrapper,
providerName string,
cfg config.OpenAI,
clientHeaders http.Header,
authHeaderName string,
tracer trace.Tracer,
) *BlockingInterception {
return &BlockingInterception{interceptionBase: interceptionBase{
id: id,
providerName: providerName,
req: req,
cfg: cfg,
clientHeaders: clientHeaders,
Expand Down
2 changes: 2 additions & 0 deletions intercept/chatcompletions/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,15 @@ type StreamingInterception struct {
func NewStreamingInterceptor(
id uuid.UUID,
req *ChatCompletionNewParamsWrapper,
providerName string,
cfg config.OpenAI,
clientHeaders http.Header,
authHeaderName string,
tracer trace.Tracer,
) *StreamingInterception {
return &StreamingInterception{interceptionBase: interceptionBase{
id: id,
providerName: providerName,
req: req,
cfg: cfg,
clientHeaders: clientHeaders,
Expand Down
2 changes: 1 addition & 1 deletion intercept/chatcompletions/streaming_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func TestStreamingInterception_RelaysUpstreamErrorToClient(t *testing.T) {
httpReq := httptest.NewRequest(http.MethodPost, "/chat/completions", nil)

tracer := otel.Tracer("test")
interceptor := NewStreamingInterceptor(uuid.New(), req, cfg, httpReq.Header, "Authorization", tracer)
interceptor := NewStreamingInterceptor(uuid.New(), req, config.ProviderOpenAI, cfg, httpReq.Header, "Authorization", tracer)

logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
interceptor.Setup(logger, &testutil.MockRecorder{}, nil)
Expand Down
2 changes: 2 additions & 0 deletions intercept/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ type Interceptor interface {
Streaming() bool
// TraceAttributes returns tracing attributes for this [Interceptor]
TraceAttributes(*http.Request) []attribute.KeyValue
// ProviderName returns the provider name for this interception.
ProviderName() string
// CorrelatingToolCallID returns the ID of a tool call result submitted
// in the request, if present. This is used to correlate the current
// interception back to the previous interception that issued those tool
Expand Down
Loading
Loading