diff --git a/bridge.go b/bridge.go index 4d79fba..e881a7c 100644 --- a/bridge.go +++ b/bridge.go @@ -13,6 +13,7 @@ import ( "cdr.dev/slog/v3" "github.com/coder/aibridge/circuitbreaker" + "github.com/google/uuid" aibcontext "github.com/coder/aibridge/context" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/metrics" @@ -196,6 +197,34 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC }() } + // For Coder Agents, the X-Coder-Owner-Id header identifies the actual + // user who initiated the chat. Override the actor so usage is attributed + // to the correct user rather than the service-level identity. + if client == ClientCoderAgents { + if ownerID := r.Header.Get("X-Coder-Owner-Id"); ownerID != "" { + if _, err := uuid.Parse(ownerID); err != nil { + logger.Warn(ctx, "ignoring invalid X-Coder-Owner-Id, expected UUID", + slog.F("value", ownerID), + slog.Error(err), + ) + } else { + existingActor := aibcontext.ActorFromContext(ctx) + var md recorder.Metadata + var previousActorID string + if existingActor != nil { + md = existingActor.Metadata + previousActorID = existingActor.ID + } + logger.Debug(ctx, "overriding initiator with X-Coder-Owner-Id", + slog.F("previous_actor_id", previousActorID), + slog.F("new_actor_id", ownerID), + ) + ctx = aibcontext.AsActor(ctx, ownerID, md) + r = r.WithContext(ctx) + } + } + } + actor := aibcontext.ActorFromContext(ctx) if actor == nil { logger.Warn(ctx, "no actor found in context") diff --git a/internal/integrationtest/bridge_test.go b/internal/integrationtest/bridge_test.go index 4f5f72d..3eaa90f 100644 --- a/internal/integrationtest/bridge_test.go +++ b/internal/integrationtest/bridge_test.go @@ -12,6 +12,8 @@ import ( "testing" "time" + "cdr.dev/slog/v3" + "cdr.dev/slog/v3/sloggers/sloghuman" "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/packages/ssestream" "github.com/anthropics/anthropic-sdk-go/shared/constant" @@ -2090,3 +2092,82 @@ func TestActorHeaders(t *testing.T) { } } } + +func TestCoderAgentsInitiatorOverride(t *testing.T) { + t.Parallel() + + const overrideActorID = "b1c2d3e4-5678-4a9b-8c0d-1e2f3a4b5c6d" + + cases := []struct { + name string + userAgent string + ownerIDHeader string + expectInitiator string + expectLogOverride bool + }{ + { + name: "coder_agents_with_owner_id", + userAgent: "coder-agents/v2.24.0 (linux/amd64)", + ownerIDHeader: overrideActorID, + expectInitiator: overrideActorID, + expectLogOverride: true, + }, + { + name: "coder_agents_without_owner_id", + userAgent: "coder-agents/v2.24.0 (linux/amd64)", + ownerIDHeader: "", + expectInitiator: defaultActorID, + }, + { + name: "coder_agents_with_invalid_owner_id", + userAgent: "coder-agents/v2.24.0 (linux/amd64)", + ownerIDHeader: "not-a-uuid", + expectInitiator: defaultActorID, + }, + { + name: "non_coder_agents_with_owner_id_header", + userAgent: "claude-code/1.0.0", + ownerIDHeader: overrideActorID, + expectInitiator: defaultActorID, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + fix := fixtures.Parse(t, fixtures.AntSimple) + upstream := newMockUpstream(t, ctx, newFixtureResponse(fix)) + + var logBuf bytes.Buffer + logger := slog.Make(sloghuman.Sink(&logBuf)).Leveled(slog.LevelDebug) + bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, withLogger(logger)) + + headers := http.Header{"User-Agent": {tc.userAgent}} + if tc.ownerIDHeader != "" { + headers.Set("X-Coder-Owner-Id", tc.ownerIDHeader) + } + + resp := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, fix.Request(), headers) + require.Equal(t, http.StatusOK, resp.StatusCode) + _, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + interceptions := bridgeServer.Recorder.RecordedInterceptions() + require.Len(t, interceptions, 1) + require.Equal(t, tc.expectInitiator, interceptions[0].InitiatorID) + + logOutput := logBuf.String() + if tc.expectLogOverride { + assert.Contains(t, logOutput, "overriding initiator with X-Coder-Owner-Id") + assert.Contains(t, logOutput, defaultActorID) + assert.Contains(t, logOutput, overrideActorID) + } else { + assert.NotContains(t, logOutput, "overriding initiator with X-Coder-Owner-Id") + } + }) + } +} diff --git a/internal/integrationtest/setupbridge.go b/internal/integrationtest/setupbridge.go index bb999d2..effbdc3 100644 --- a/internal/integrationtest/setupbridge.go +++ b/internal/integrationtest/setupbridge.go @@ -51,6 +51,7 @@ type bridgeConfig struct { userID string metadata recorder.Metadata logger slog.Logger + loggerSet bool } // bridgeTestServer wraps an httptest.Server running a RequestBridge. @@ -119,6 +120,11 @@ func withMCP(p mcp.ServerProxier) bridgeOption { return func(c *bridgeConfig) { c.mcpProxy = p } } +// withLogger overrides the default test logger. +func withLogger(l slog.Logger) bridgeOption { + return func(c *bridgeConfig) { c.logger = l; c.loggerSet = true } +} + // withActor sets the actor ID and metadata for the BaseContext. func withActor(id string, md recorder.Metadata) bridgeOption { return func(c *bridgeConfig) { c.userID = id; c.metadata = md } @@ -148,7 +154,9 @@ func newBridgeTestServer( if cfg.tracer == nil { cfg.tracer = defaultTracer } - cfg.logger = newLogger(t) + if !cfg.loggerSet { + cfg.logger = newLogger(t) + } if cfg.mcpProxy == nil { cfg.mcpProxy = newNoopMCPManager() }