Skip to content
Open
5 changes: 5 additions & 0 deletions bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC
asyncRecorder.WithClient(string(client))
interceptor.Setup(logger, asyncRecorder, mcpProxy)

cred := interceptor.Credential()
if err := rec.RecordInterception(ctx, &recorder.InterceptionRecord{
ID: interceptor.ID().String(),
InitiatorID: actor.ID,
Expand All @@ -228,6 +229,8 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC
Client: string(client),
ClientSessionID: sessionID,
CorrelatingToolCallID: interceptor.CorrelatingToolCallID(),
CredentialKind: string(cred.Kind),
CredentialHint: cred.Hint,
}); err != nil {
span.SetStatus(codes.Error, fmt.Sprintf("failed to record interception: %v", err))
logger.Warn(ctx, "failed to record interception", slog.Error(err))
Expand All @@ -242,6 +245,8 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC
slog.F("interception_id", interceptor.ID()),
slog.F("user_agent", r.UserAgent()),
slog.F("streaming", interceptor.Streaming()),
slog.F("credential_kind", string(cred.Kind)),
slog.F("credential_hint", cred.Hint),
)

log.Debug(ctx, "interception started")
Expand Down
9 changes: 7 additions & 2 deletions intercept/chatcompletions/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ type interceptionBase struct {
logger slog.Logger
tracer trace.Tracer

recorder recorder.Recorder
mcpProxy mcp.ServerProxier
recorder recorder.Recorder
mcpProxy mcp.ServerProxier
credential intercept.CredentialInfo
}

func (i *interceptionBase) newCompletionsService() openai.ChatCompletionService {
Expand Down Expand Up @@ -74,6 +75,10 @@ func (i *interceptionBase) ID() uuid.UUID {
return i.id
}

func (i *interceptionBase) Credential() intercept.CredentialInfo {
return i.credential
}

func (i *interceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) {
i.logger = logger
i.recorder = recorder
Expand Down
2 changes: 2 additions & 0 deletions intercept/chatcompletions/blocking.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ func NewBlockingInterceptor(
clientHeaders http.Header,
authHeaderName string,
tracer trace.Tracer,
cred intercept.CredentialInfo,
) *BlockingInterception {
return &BlockingInterception{interceptionBase: interceptionBase{
id: id,
Expand All @@ -45,6 +46,7 @@ func NewBlockingInterceptor(
clientHeaders: clientHeaders,
authHeaderName: authHeaderName,
tracer: tracer,
credential: cred,
}}
}

Expand Down
2 changes: 2 additions & 0 deletions intercept/chatcompletions/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ func NewStreamingInterceptor(
clientHeaders http.Header,
authHeaderName string,
tracer trace.Tracer,
cred intercept.CredentialInfo,
) *StreamingInterception {
return &StreamingInterception{interceptionBase: interceptionBase{
id: id,
Expand All @@ -50,6 +51,7 @@ func NewStreamingInterceptor(
clientHeaders: clientHeaders,
authHeaderName: authHeaderName,
tracer: tracer,
credential: cred,
}}
}

Expand Down
3 changes: 2 additions & 1 deletion intercept/chatcompletions/streaming_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/aibridge/config"
"github.com/coder/aibridge/intercept"
"github.com/coder/aibridge/internal/testutil"
"github.com/google/uuid"
"github.com/openai/openai-go/v3"
Expand Down Expand Up @@ -86,7 +87,7 @@ func TestStreamingInterception_RelaysUpstreamErrorToClient(t *testing.T) {
httpReq := httptest.NewRequest(http.MethodPost, "/chat/completions", nil)

tracer := otel.Tracer("test")
interceptor := NewStreamingInterceptor(uuid.New(), req, config.ProviderOpenAI, cfg, httpReq.Header, "Authorization", tracer)
interceptor := NewStreamingInterceptor(uuid.New(), req, config.ProviderOpenAI, cfg, httpReq.Header, "Authorization", tracer, intercept.CredentialInfo{})

logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
interceptor.Setup(logger, &testutil.MockRecorder{}, nil)
Expand Down
30 changes: 30 additions & 0 deletions intercept/credential.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package intercept

import "github.com/coder/aibridge/utils"

// CredentialKind identifies how a request was authenticated.
// Keep in sync with the credential_kind enum in coderd's database.
type CredentialKind string

// Credential kind constants for interception recording.
const (
CredentialKindCentralized CredentialKind = "centralized"
CredentialKindPersonalAPIKey CredentialKind = "byok_api_key"
CredentialKindSubscription CredentialKind = "byok_subscription"
)

// CredentialInfo holds credential metadata for an interception.
type CredentialInfo struct {
Kind CredentialKind
Hint string
}

// NewCredentialInfo creates a CredentialInfo from a raw credential.
// The credential is automatically masked before storage so that the
// original secret is never retained.
func NewCredentialInfo(kind CredentialKind, credential string) CredentialInfo {
return CredentialInfo{
Kind: kind,
Hint: utils.MaskSecret(credential),
}
}
2 changes: 2 additions & 0 deletions intercept/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ type Interceptor interface {
Streaming() bool
// TraceAttributes returns tracing attributes for this [Interceptor]
TraceAttributes(*http.Request) []attribute.KeyValue
// Credential returns the credential metadata for this interception.
Credential() CredentialInfo
// CorrelatingToolCallID returns the ID of a tool call result submitted
// in the request, if present. This is used to correlate the current
// interception back to the previous interception that issued those tool
Expand Down
9 changes: 7 additions & 2 deletions intercept/messages/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,19 @@ type interceptionBase struct {
tracer trace.Tracer
logger slog.Logger

recorder recorder.Recorder
mcpProxy mcp.ServerProxier
recorder recorder.Recorder
mcpProxy mcp.ServerProxier
credential intercept.CredentialInfo
}

func (i *interceptionBase) ID() uuid.UUID {
return i.id
}

func (i *interceptionBase) Credential() intercept.CredentialInfo {
return i.credential
}

func (i *interceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) {
i.logger = logger
i.recorder = recorder
Expand Down
2 changes: 2 additions & 0 deletions intercept/messages/blocking.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ func NewBlockingInterceptor(
clientHeaders http.Header,
authHeaderName string,
tracer trace.Tracer,
cred intercept.CredentialInfo,
) *BlockingInterception {
return &BlockingInterception{interceptionBase: interceptionBase{
id: id,
Expand All @@ -47,6 +48,7 @@ func NewBlockingInterceptor(
clientHeaders: clientHeaders,
authHeaderName: authHeaderName,
tracer: tracer,
credential: cred,
}}
}

Expand Down
2 changes: 2 additions & 0 deletions intercept/messages/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ func NewStreamingInterceptor(
clientHeaders http.Header,
authHeaderName string,
tracer trace.Tracer,
cred intercept.CredentialInfo,
) *StreamingInterception {
return &StreamingInterception{interceptionBase: interceptionBase{
id: id,
Expand All @@ -53,6 +54,7 @@ func NewStreamingInterceptor(
clientHeaders: clientHeaders,
authHeaderName: authHeaderName,
tracer: tracer,
credential: cred,
}}
}

Expand Down
9 changes: 7 additions & 2 deletions intercept/responses/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ type responsesInterceptionBase struct {
recorder recorder.Recorder
mcpProxy mcp.ServerProxier

logger slog.Logger
tracer trace.Tracer
logger slog.Logger
tracer trace.Tracer
credential intercept.CredentialInfo
}

func (i *responsesInterceptionBase) newResponsesService() responses.ResponseService {
Expand Down Expand Up @@ -83,6 +84,10 @@ func (i *responsesInterceptionBase) ID() uuid.UUID {
return i.id
}

func (i *responsesInterceptionBase) Credential() intercept.CredentialInfo {
return i.credential
}

func (i *responsesInterceptionBase) Setup(logger slog.Logger, recorder recorder.Recorder, mcpProxy mcp.ServerProxier) {
i.logger = logger.With(slog.F("model", i.Model()))
i.recorder = recorder
Expand Down
2 changes: 2 additions & 0 deletions intercept/responses/blocking.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ func NewBlockingInterceptor(
clientHeaders http.Header,
authHeaderName string,
tracer trace.Tracer,
cred intercept.CredentialInfo,
) *BlockingResponsesInterceptor {
return &BlockingResponsesInterceptor{
responsesInterceptionBase: responsesInterceptionBase{
Expand All @@ -43,6 +44,7 @@ func NewBlockingInterceptor(
clientHeaders: clientHeaders,
authHeaderName: authHeaderName,
tracer: tracer,
credential: cred,
},
}
}
Expand Down
2 changes: 2 additions & 0 deletions intercept/responses/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ func NewStreamingInterceptor(
clientHeaders http.Header,
authHeaderName string,
tracer trace.Tracer,
cred intercept.CredentialInfo,
) *StreamingResponsesInterceptor {
return &StreamingResponsesInterceptor{
responsesInterceptionBase: responsesInterceptionBase{
Expand All @@ -50,6 +51,7 @@ func NewStreamingInterceptor(
clientHeaders: clientHeaders,
authHeaderName: authHeaderName,
tracer: tracer,
credential: cred,
},
}
}
Expand Down
12 changes: 10 additions & 2 deletions provider/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,21 +130,29 @@ func (p *Anthropic) CreateInterceptor(w http.ResponseWriter, r *http.Request, tr
// set BYOKBearerToken and clear the centralized key.
// When both are present, X-Api-Key takes priority to match
// claude-code behavior.
credKind := intercept.CredentialKindCentralized
credSecret := cfg.Key
authHeaderName := p.AuthHeader()
if apiKey := r.Header.Get("X-Api-Key"); apiKey != "" {
cfg.Key = apiKey
authHeaderName = "X-Api-Key"
credKind = intercept.CredentialKindPersonalAPIKey
credSecret = apiKey
} else if token := utils.ExtractBearerToken(r.Header.Get("Authorization")); token != "" {
cfg.BYOKBearerToken = token
cfg.Key = ""
authHeaderName = "Authorization"
credKind = intercept.CredentialKindSubscription
credSecret = token
}

cred := intercept.NewCredentialInfo(credKind, credSecret)

var interceptor intercept.Interceptor
if reqPayload.Stream() {
interceptor = messages.NewStreamingInterceptor(id, reqPayload, p.Name(), cfg, p.bedrockCfg, r.Header, authHeaderName, tracer)
interceptor = messages.NewStreamingInterceptor(id, reqPayload, p.Name(), cfg, p.bedrockCfg, r.Header, authHeaderName, tracer, cred)
} else {
interceptor = messages.NewBlockingInterceptor(id, reqPayload, p.Name(), cfg, p.bedrockCfg, r.Header, authHeaderName, tracer)
interceptor = messages.NewBlockingInterceptor(id, reqPayload, p.Name(), cfg, p.bedrockCfg, r.Header, authHeaderName, tracer, cred)
}
span.SetAttributes(interceptor.TraceAttributes(r)...)
return interceptor, nil
Expand Down
43 changes: 29 additions & 14 deletions provider/anthropic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/stretchr/testify/require"

"github.com/coder/aibridge/config"
"github.com/coder/aibridge/intercept"
"github.com/coder/aibridge/internal/testutil"
)

Expand Down Expand Up @@ -163,33 +164,43 @@ func TestAnthropic_CreateInterceptor_BYOK(t *testing.T) {
t.Parallel()

tests := []struct {
name string
setHeaders map[string]string
wantXApiKey string
wantAuthorization string
name string
setHeaders map[string]string
wantXApiKey string
wantAuthorization string
wantCredentialKind intercept.CredentialKind
wantCredentialHint string
}{
{
name: "Messages_BYOK_BearerToken",
setHeaders: map[string]string{"Authorization": "Bearer user-access-token"},
wantAuthorization: "Bearer user-access-token",
name: "Messages_BYOK_BearerToken",
setHeaders: map[string]string{"Authorization": "Bearer user-access-token"},
wantAuthorization: "Bearer user-access-token",
wantCredentialKind: intercept.CredentialKindSubscription,
wantCredentialHint: "us*************en",
},
{
name: "Messages_BYOK_APIKey",
setHeaders: map[string]string{"X-Api-Key": "user-api-key"},
wantXApiKey: "user-api-key",
name: "Messages_BYOK_APIKey",
setHeaders: map[string]string{"X-Api-Key": "user-api-key"},
wantXApiKey: "user-api-key",
wantCredentialKind: intercept.CredentialKindPersonalAPIKey,
wantCredentialHint: "us********ey",
},
{
name: "Messages_Centralized_UsesCentralizedKey",
setHeaders: map[string]string{},
wantXApiKey: "test-key",
name: "Messages_Centralized_UsesCentralizedKey",
setHeaders: map[string]string{},
wantXApiKey: "test-key",
wantCredentialKind: intercept.CredentialKindCentralized,
wantCredentialHint: "********",
},
{
name: "Messages_BYOK_BearerToken_And_APIKey",
setHeaders: map[string]string{
"Authorization": "Bearer user-access-token",
"X-Api-Key": "user-api-key",
},
wantXApiKey: "user-api-key",
wantXApiKey: "user-api-key",
wantCredentialKind: intercept.CredentialKindPersonalAPIKey,
wantCredentialHint: "us********ey",
},
}

Expand Down Expand Up @@ -223,6 +234,10 @@ func TestAnthropic_CreateInterceptor_BYOK(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, interceptor)

cred := interceptor.Credential()
assert.Equal(t, tc.wantCredentialKind, cred.Kind, "credential kind mismatch")
assert.Equal(t, tc.wantCredentialHint, cred.Hint, "credential hint mismatch")

logger := slog.Make()
interceptor.Setup(logger, &testutil.MockRecorder{}, nil)

Expand Down
10 changes: 6 additions & 4 deletions provider/copilot.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac
ExtraHeaders: extractCopilotHeaders(r),
}

cred := intercept.NewCredentialInfo(intercept.CredentialKindSubscription, key)

var interceptor intercept.Interceptor

path := strings.TrimPrefix(r.URL.Path, p.RoutePrefix())
Expand All @@ -156,9 +158,9 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac
}

if req.Stream {
interceptor = chatcompletions.NewStreamingInterceptor(id, &req, p.Name(), cfg, r.Header, p.AuthHeader(), tracer)
interceptor = chatcompletions.NewStreamingInterceptor(id, &req, p.Name(), cfg, r.Header, p.AuthHeader(), tracer, cred)
} else {
interceptor = chatcompletions.NewBlockingInterceptor(id, &req, p.Name(), cfg, r.Header, p.AuthHeader(), tracer)
interceptor = chatcompletions.NewBlockingInterceptor(id, &req, p.Name(), cfg, r.Header, p.AuthHeader(), tracer, cred)
}

case routeCopilotResponses:
Expand All @@ -172,9 +174,9 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac
}

if reqPayload.Stream() {
interceptor = responses.NewStreamingInterceptor(id, reqPayload, p.Name(), cfg, r.Header, p.AuthHeader(), tracer)
interceptor = responses.NewStreamingInterceptor(id, reqPayload, p.Name(), cfg, r.Header, p.AuthHeader(), tracer, cred)
} else {
interceptor = responses.NewBlockingInterceptor(id, reqPayload, p.Name(), cfg, r.Header, p.AuthHeader(), tracer)
interceptor = responses.NewBlockingInterceptor(id, reqPayload, p.Name(), cfg, r.Header, p.AuthHeader(), tracer, cred)
}

default:
Expand Down
Loading
Loading