diff --git a/pkg/model/provider/anthropic/client.go b/pkg/model/provider/anthropic/client.go index 4e9c5bfe1..bc604363b 100644 --- a/pkg/model/provider/anthropic/client.go +++ b/pkg/model/provider/anthropic/client.go @@ -74,8 +74,16 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro slog.ErrorContext(ctx, "Anthropic client creation failed", "error", err) return nil, err } + httpClient := httpclient.NewHTTPClient(ctx) + if w := globalOptions.TransportWrapper(); w != nil { + if wrapped := w(httpClient.Transport); wrapped != nil { + httpClient.Transport = wrapped + } else { + slog.WarnContext(ctx, "HTTP transport wrapper returned nil; using original transport") + } + } requestOptions := append([]option.RequestOption{ - option.WithHTTPClient(httpclient.NewHTTPClient(ctx)), + option.WithHTTPClient(httpClient), }, authOpts...) if cfg.BaseURL != "" { requestOptions = append(requestOptions, option.WithBaseURL(cfg.BaseURL)) @@ -127,9 +135,17 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro httpOptions = append(httpOptions, httpclient.WithHeader("X-Cagent-GeneratingTitle", "1")) } + gatewayHTTPClient := httpclient.NewHTTPClient(ctx, httpOptions...) + if w := globalOptions.TransportWrapper(); w != nil { + if wrapped := w(gatewayHTTPClient.Transport); wrapped != nil { + gatewayHTTPClient.Transport = wrapped + } else { + slog.WarnContext(ctx, "HTTP transport wrapper returned nil; using original transport") + } + } clientOptions := []option.RequestOption{ option.WithBaseURL(baseURL), - option.WithHTTPClient(httpclient.NewHTTPClient(ctx, httpOptions...)), + option.WithHTTPClient(gatewayHTTPClient), } if authToken != "" { clientOptions = append(clientOptions, option.WithAuthToken(authToken), option.WithAPIKey(authToken)) diff --git a/pkg/model/provider/anthropic/wrap_transport_test.go b/pkg/model/provider/anthropic/wrap_transport_test.go new file mode 100644 index 000000000..5eac2140e --- /dev/null +++ b/pkg/model/provider/anthropic/wrap_transport_test.go @@ -0,0 +1,157 @@ +package anthropic + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/config/latest" + "github.com/docker/docker-agent/pkg/environment" + "github.com/docker/docker-agent/pkg/model/provider/options" +) + +// countingTransport wraps a base RoundTripper and counts how many times +// RoundTrip is called. +type countingTransport struct { + base http.RoundTripper + calls atomic.Int64 +} + +func (c *countingTransport) RoundTrip(req *http.Request) (*http.Response, error) { + c.calls.Add(1) + return c.base.RoundTrip(req) +} + +// writeMinimalAnthropicSSE writes a bare-minimum valid Anthropic SSE stream +// so that the streaming client does not error before we can observe transport invocation. +func writeMinimalAnthropicSSE(w http.ResponseWriter) { + w.Header().Set("Content-Type", "text/event-stream") + flusher, _ := w.(http.Flusher) + + writeEvent := func(eventType string, payload any) { + data, _ := json.Marshal(payload) + _, _ = w.Write([]byte("event: " + eventType + "\n")) + _, _ = w.Write([]byte("data: " + string(data) + "\n\n")) + if flusher != nil { + flusher.Flush() + } + } + + writeEvent("message_start", map[string]any{ + "type": "message_start", + "message": map[string]any{"id": "msg_test", "model": "claude-test", "role": "assistant", "type": "message", "content": []any{}, "stop_reason": nil, "usage": map[string]any{"input_tokens": 5, "output_tokens": 0}}, + }) + writeEvent("content_block_start", map[string]any{ + "type": "content_block_start", + "index": 0, + "content_block": map[string]any{ + "type": "text", + "text": "", + }, + }) + writeEvent("content_block_delta", map[string]any{ + "type": "content_block_delta", + "index": 0, + "delta": map[string]any{"type": "text_delta", "text": "hi"}, + }) + writeEvent("content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": 0, + }) + writeEvent("message_delta", map[string]any{ + "type": "message_delta", + "delta": map[string]any{"stop_reason": "end_turn", "stop_sequence": nil}, + "usage": map[string]any{"output_tokens": 1}, + }) + writeEvent("message_stop", map[string]any{"type": "message_stop"}) +} + +func TestNewClient_TransportWrapperInvokedDirectPath(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + writeMinimalAnthropicSSE(w) + })) + defer server.Close() + + var counter countingTransport + + cfg := &latest.ModelConfig{ + Provider: "anthropic", + Model: "claude-3-5-haiku-latest", + BaseURL: server.URL, + } + env := environment.NewMapEnvProvider(map[string]string{ + "ANTHROPIC_API_KEY": "test-key", + }) + + client, err := NewClient(t.Context(), cfg, env, + options.WithHTTPTransportWrapper(func(base http.RoundTripper) http.RoundTripper { + counter.base = base + return &counter + }), + ) + require.NoError(t, err) + + stream, err := client.CreateChatCompletionStream(t.Context(), []chat.Message{ + {Role: chat.MessageRoleUser, Content: "hello"}, + }, nil) + require.NoError(t, err) + defer stream.Close() + + // Drain the stream so RoundTrip has been fully exercised. + for { + if _, err := stream.Recv(); err != nil { + break + } + } + + assert.Positive(t, counter.calls.Load(), "transport wrapper RoundTrip should have been called at least once") +} + +func TestNewClient_TransportWrapperInvokedGatewayPath(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + writeMinimalAnthropicSSE(w) + })) + defer server.Close() + + var counter countingTransport + + cfg := &latest.ModelConfig{ + Provider: "anthropic", + Model: "claude-3-5-haiku-latest", + } + // server.URL is 127.0.0.1 which IsTrustedDockerURL considers trusted, + // so we must supply the Docker Desktop token. + env := environment.NewMapEnvProvider(map[string]string{ + environment.DockerDesktopTokenEnv: "test-dd-token", + }) + + client, err := NewClient(t.Context(), cfg, env, + options.WithGateway(server.URL), + options.WithHTTPTransportWrapper(func(base http.RoundTripper) http.RoundTripper { + counter.base = base + return &counter + }), + ) + require.NoError(t, err) + + stream, err := client.CreateChatCompletionStream(t.Context(), []chat.Message{ + {Role: chat.MessageRoleUser, Content: "hello"}, + }, nil) + require.NoError(t, err) + defer stream.Close() + + // Drain the stream so RoundTrip has been fully exercised. + for { + if _, err := stream.Recv(); err != nil { + break + } + } + + assert.Positive(t, counter.calls.Load(), "transport wrapper RoundTrip should have been called at least once in gateway path") +} diff --git a/pkg/model/provider/bedrock/client.go b/pkg/model/provider/bedrock/client.go index 282ef56f2..fb6bcce1e 100644 --- a/pkg/model/provider/bedrock/client.go +++ b/pkg/model/provider/bedrock/client.go @@ -18,6 +18,7 @@ import ( "github.com/docker/docker-agent/pkg/chat" "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/environment" + "github.com/docker/docker-agent/pkg/httpclient" "github.com/docker/docker-agent/pkg/model/provider/base" "github.com/docker/docker-agent/pkg/model/provider/options" "github.com/docker/docker-agent/pkg/model/provider/providerutil" @@ -60,7 +61,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro opt(&globalOptions) } - // Check for bearer token - use token_key if specified, otherwise try AWS_BEARER_TOKEN_BEDROCK. + // Check for bearer token // Bearer token is optional: if not provided, falls back to standard AWS credential chain (SigV4). // // NOTE: Manual token handling is required because aws-sdk-go-v2's default credential chain @@ -77,6 +78,28 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro bearerToken, _ = env.Get(ctx, "AWS_BEARER_TOKEN_BEDROCK") } + // Build the docker-agent HTTP client (OTel instrumentation, SSE decompression, + // Desktop proxy support) so transport-level concerns apply to Bedrock too. + httpClient := httpclient.NewHTTPClient(ctx) + + // If a bearer token is set, chain it on top of the base transport so auth + // headers are injected without replacing the rest of the transport stack. + if bearerToken != "" { + httpClient.Transport = &bearerTokenTransport{ + token: bearerToken, + base: httpClient.Transport, + } + } + + // Apply the transport wrapper, if registered, over the full chain. + if w := globalOptions.TransportWrapper(); w != nil { + if wrapped := w(httpClient.Transport); wrapped != nil { + httpClient.Transport = wrapped + } else { + slog.WarnContext(ctx, "HTTP transport wrapper returned nil; using original transport") + } + } + // Build AWS config using default credential chain awsCfg, err := buildAWSConfig(ctx, cfg, env) if err != nil { @@ -94,19 +117,18 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro }) } - // If bearer token is set, use it instead of SigV4 + // Inject our HTTP client (which carries OTel, SSE, bearer token if set, and + // any caller-registered transport wrapper) into the Bedrock runtime options. if bearerToken != "" { slog.DebugContext(ctx, "Bedrock using bearer token authentication") clientOpts = append(clientOpts, func(o *bedrockruntime.Options) { // Use anonymous credentials to skip SigV4 signing o.Credentials = aws.AnonymousCredentials{} - // Add bearer token via custom HTTP client - o.HTTPClient = &http.Client{ - Transport: &bearerTokenTransport{ - token: bearerToken, - base: http.DefaultTransport, - }, - } + o.HTTPClient = httpClient + }) + } else { + clientOpts = append(clientOpts, func(o *bedrockruntime.Options) { + o.HTTPClient = httpClient }) } diff --git a/pkg/model/provider/gemini/client.go b/pkg/model/provider/gemini/client.go index 593308359..c7e2e4394 100644 --- a/pkg/model/provider/gemini/client.go +++ b/pkg/model/provider/gemini/client.go @@ -61,8 +61,20 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro project string location string ) - // project/location take priority over API key, like in the genai client. - if cfg.ProviderOpts["project"] != nil || cfg.ProviderOpts["location"] != nil { + // Determine whether Vertex AI would normally be used, then check whether + // an HTTP transport wrapper forces a fallback to BackendGeminiAPI. + // The Vertex AI backend relies on ADC-managed HTTP clients that bypass + // http.RoundTripper, so the wrapper cannot be applied there. + _, useVertexAIEnv := env.Get(ctx, "GOOGLE_GENAI_USE_VERTEXAI") + wantVertexAI := cfg.ProviderOpts["project"] != nil || cfg.ProviderOpts["location"] != nil || useVertexAIEnv + useVertexAI := wantVertexAI && globalOptions.TransportWrapper() == nil + + if wantVertexAI && !useVertexAI { + slog.DebugContext(ctx, "Vertex AI requested but HTTP transport wrapper is set, falling back to GeminiAPI backend") + } + + switch { + case useVertexAI && (cfg.ProviderOpts["project"] != nil || cfg.ProviderOpts["location"] != nil): var err error project, err = environment.Expand(ctx, providerOption(cfg, "project"), env) @@ -82,13 +94,13 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro } backend = genai.BackendVertexAI - httpClient = nil // Use default client - } else if _, exist := env.Get(ctx, "GOOGLE_GENAI_USE_VERTEXAI"); exist { + httpClient = nil // Use ADC-managed client + case useVertexAI: project, _ = env.Get(ctx, "GOOGLE_CLOUD_PROJECT") location, _ = env.Get(ctx, "GOOGLE_CLOUD_LOCATION") backend = genai.BackendVertexAI - httpClient = nil // Use default client - } else { + httpClient = nil // Use ADC-managed client + default: if value, exist := env.Get(ctx, "GEMINI_API_KEY"); exist { apiKey = value } @@ -103,6 +115,14 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro httpClient = httpclient.NewHTTPClient(ctx) } + if w := globalOptions.TransportWrapper(); w != nil { + if wrapped := w(httpClient.Transport); wrapped != nil { + httpClient.Transport = wrapped + } else { + slog.WarnContext(ctx, "HTTP transport wrapper returned nil; using original transport") + } + } + client, err := genai.NewClient(ctx, &genai.ClientConfig{ APIKey: apiKey, Project: project, @@ -168,10 +188,19 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro } } + gatewayHTTPClient := httpclient.NewHTTPClient(ctx, httpOptions...) + if w := globalOptions.TransportWrapper(); w != nil { + if wrapped := w(gatewayHTTPClient.Transport); wrapped != nil { + gatewayHTTPClient.Transport = wrapped + } else { + slog.WarnContext(ctx, "HTTP transport wrapper returned nil; using original transport") + } + } + return genai.NewClient(ctx, &genai.ClientConfig{ APIKey: authToken, Backend: genai.BackendGeminiAPI, - HTTPClient: httpclient.NewHTTPClient(ctx, httpOptions...), + HTTPClient: gatewayHTTPClient, HTTPOptions: httpOpts, }) } diff --git a/pkg/model/provider/gemini/wrap_transport_test.go b/pkg/model/provider/gemini/wrap_transport_test.go new file mode 100644 index 000000000..ee524e4f5 --- /dev/null +++ b/pkg/model/provider/gemini/wrap_transport_test.go @@ -0,0 +1,174 @@ +package gemini + +import ( + "fmt" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/config/latest" + "github.com/docker/docker-agent/pkg/environment" + "github.com/docker/docker-agent/pkg/model/provider/options" +) + +// geminiCountingTransport wraps a base RoundTripper and counts RoundTrip calls. +type geminiCountingTransport struct { + base http.RoundTripper + calls atomic.Int64 +} + +func (c *geminiCountingTransport) RoundTrip(req *http.Request) (*http.Response, error) { + c.calls.Add(1) + return c.base.RoundTrip(req) +} + +// writeGeminiSSEResponse writes a minimal valid Gemini streaming response. +func writeGeminiSSEResponse(w http.ResponseWriter) { + w.Header().Set("Content-Type", "text/event-stream") + flusher, _ := w.(http.Flusher) + + payload := `{"candidates":[{"content":{"parts":[{"text":"hi"}],"role":"model"},"finishReason":"STOP","index":0}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2}}` + _, _ = fmt.Fprintf(w, "data: %s\n\n", payload) + if flusher != nil { + flusher.Flush() + } +} + +func TestNewClient_TransportWrapperInvokedDirectPath(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + writeGeminiSSEResponse(w) + })) + defer server.Close() + + var counter geminiCountingTransport + + cfg := &latest.ModelConfig{ + Provider: "google", + Model: "gemini-2.0-flash", + BaseURL: server.URL, + } + env := environment.NewMapEnvProvider(map[string]string{ + "GOOGLE_API_KEY": "test-key", + }) + + client, err := NewClient(t.Context(), cfg, env, + options.WithHTTPTransportWrapper(func(base http.RoundTripper) http.RoundTripper { + counter.base = base + return &counter + }), + ) + require.NoError(t, err) + + stream, err := client.CreateChatCompletionStream(t.Context(), []chat.Message{ + {Role: chat.MessageRoleUser, Content: "hello"}, + }, nil) + require.NoError(t, err) + defer stream.Close() + + // Drain the stream so RoundTrip has been fully exercised. + for { + if _, err := stream.Recv(); err != nil { + break + } + } + + assert.Positive(t, counter.calls.Load(), "transport wrapper RoundTrip should have been called at least once") +} + +func TestNewClient_TransportWrapperInvokedGatewayPath(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + writeGeminiSSEResponse(w) + })) + defer server.Close() + + var counter geminiCountingTransport + + cfg := &latest.ModelConfig{ + Provider: "google", + Model: "gemini-2.0-flash", + } + // server.URL is 127.0.0.1 which IsTrustedDockerURL considers trusted, + // so we must supply the Docker Desktop token. + env := environment.NewMapEnvProvider(map[string]string{ + environment.DockerDesktopTokenEnv: "test-dd-token", + }) + + client, err := NewClient(t.Context(), cfg, env, + options.WithGateway(server.URL), + options.WithHTTPTransportWrapper(func(base http.RoundTripper) http.RoundTripper { + counter.base = base + return &counter + }), + ) + require.NoError(t, err) + + stream, err := client.CreateChatCompletionStream(t.Context(), []chat.Message{ + {Role: chat.MessageRoleUser, Content: "hello"}, + }, nil) + require.NoError(t, err) + defer stream.Close() + + // Drain the stream so RoundTrip has been fully exercised. + for { + if _, err := stream.Recv(); err != nil { + break + } + } + + assert.Positive(t, counter.calls.Load(), "transport wrapper RoundTrip should have been called at least once in gateway path") +} + +// TestNewClient_TransportWrapperVertexAIFallsBackToGeminiAPI verifies that when +// project/location are configured (Vertex AI) but a transport wrapper is also +// set, the client automatically falls back to BackendGeminiAPI so the wrapper +// can be applied. This mirrors the WebSocket→SSE fallback in the OpenAI provider. +func TestNewClient_TransportWrapperVertexAIFallsBackToGeminiAPI(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + writeGeminiSSEResponse(w) + })) + defer server.Close() + + var counter geminiCountingTransport + + cfg := &latest.ModelConfig{ + Provider: "google", + Model: "gemini-2.0-flash", + BaseURL: server.URL, + ProviderOpts: map[string]any{ + "project": "test-project", + "location": "us-central1", + }, + } + // GOOGLE_API_KEY is required for the GeminiAPI fallback path. + env := environment.NewMapEnvProvider(map[string]string{ + "GOOGLE_API_KEY": "test-key", + }) + + client, err := NewClient(t.Context(), cfg, env, + options.WithHTTPTransportWrapper(func(base http.RoundTripper) http.RoundTripper { + counter.base = base + return &counter + }), + ) + require.NoError(t, err) + + stream, err := client.CreateChatCompletionStream(t.Context(), []chat.Message{ + {Role: chat.MessageRoleUser, Content: "hello"}, + }, nil) + require.NoError(t, err) + defer stream.Close() + + // Drain the stream so RoundTrip has been fully exercised. + for { + if _, err := stream.Recv(); err != nil { + break + } + } + + assert.Positive(t, counter.calls.Load(), "transport wrapper RoundTrip should have been called at least once (GeminiAPI fallback)") +} diff --git a/pkg/model/provider/openai/client.go b/pkg/model/provider/openai/client.go index 3c5a54a98..430d5d8aa 100644 --- a/pkg/model/provider/openai/client.go +++ b/pkg/model/provider/openai/client.go @@ -103,6 +103,13 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro clientOptions = append(clientOptions, option.WithMiddleware(oaistream.ErrorBodyMiddleware())) httpClient := httpclient.NewHTTPClient(ctx) + if w := globalOptions.TransportWrapper(); w != nil { + if wrapped := w(httpClient.Transport); wrapped != nil { + httpClient.Transport = wrapped + } else { + slog.WarnContext(ctx, "HTTP transport wrapper returned nil; using original transport") + } + } clientOptions = append(clientOptions, option.WithHTTPClient(httpClient)) client := openai.NewClient(clientOptions...) @@ -149,9 +156,17 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro httpOptions = append(httpOptions, httpclient.WithHeader("X-Cagent-GeneratingTitle", "1")) } + gatewayHTTPClient := httpclient.NewHTTPClient(ctx, httpOptions...) + if w := globalOptions.TransportWrapper(); w != nil { + if wrapped := w(gatewayHTTPClient.Transport); wrapped != nil { + gatewayHTTPClient.Transport = wrapped + } else { + slog.WarnContext(ctx, "HTTP transport wrapper returned nil; using original transport") + } + } clientOptions := []option.RequestOption{ option.WithBaseURL(baseURL), - option.WithHTTPClient(httpclient.NewHTTPClient(ctx, httpOptions...)), + option.WithHTTPClient(gatewayHTTPClient), option.WithMiddleware(oaistream.ErrorBodyMiddleware()), } if authToken != "" { @@ -177,7 +192,10 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro // Pre-create the WebSocket pool when the transport is configured. // The pool is cheap (no connections opened until the first Stream call) // and eager init avoids a data race on the lazy path. - if getTransport(cfg) == "websocket" && globalOptions.Gateway() == "" { + // WebSocket is also skipped when an HTTP transport wrapper is registered: + // gorilla/websocket dials raw TCP and never calls http.RoundTripper, so the + // wrapper cannot be applied. Fall back to SSE so the wrapper covers all calls. + if getTransport(cfg) == "websocket" && globalOptions.Gateway() == "" && globalOptions.TransportWrapper() == nil { baseURL := cmp.Or(cfg.BaseURL, "https://api.openai.com/v1") client.wsPool = newWSPool(httpToWSURL(baseURL), client.buildWSHeaderFn()) } @@ -512,10 +530,14 @@ func (c *Client) CreateResponseStream( // Choose transport: WebSocket or SSE (default). // WebSocket is disabled when using a Gateway since most gateways don't support it. + // WebSocket is also disabled when an HTTP transport wrapper is registered: gorilla/websocket + // dials raw TCP and never calls http.RoundTripper, so the wrapper cannot intercept those + // connections. Fall back to SSE so the wrapper applies to all requests. transport := getTransport(&c.ModelConfig) trackUsage := c.ModelConfig.TrackUsage == nil || *c.ModelConfig.TrackUsage - if transport == "websocket" && c.ModelOptions.Gateway() == "" { + switch { + case transport == "websocket" && c.ModelOptions.Gateway() == "" && c.ModelOptions.TransportWrapper() == nil: stream, err := c.createWebSocketStream(ctx, params) if err != nil { slog.WarnContext(ctx, "WebSocket stream failed, falling back to SSE", "error", err) @@ -524,10 +546,13 @@ func (c *Client) CreateResponseStream( slog.DebugContext(ctx, "OpenAI responses WebSocket stream created successfully", "model", c.ModelConfig.Model) return newResponseStreamAdapter(stream, trackUsage), nil } - } else if transport == "websocket" { + case transport == "websocket" && c.ModelOptions.Gateway() != "": slog.DebugContext(ctx, "WebSocket transport requested but Gateway is configured, using SSE", "model", c.ModelConfig.Model, "gateway", c.ModelOptions.Gateway()) + case transport == "websocket": + slog.DebugContext(ctx, "WebSocket transport requested but HTTP transport wrapper is set, using SSE", + "model", c.ModelConfig.Model) } client, err := c.clientFn(ctx) diff --git a/pkg/model/provider/openai/wrap_transport_test.go b/pkg/model/provider/openai/wrap_transport_test.go new file mode 100644 index 000000000..95d4afdd8 --- /dev/null +++ b/pkg/model/provider/openai/wrap_transport_test.go @@ -0,0 +1,148 @@ +package openai + +import ( + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/config/latest" + "github.com/docker/docker-agent/pkg/environment" + "github.com/docker/docker-agent/pkg/model/provider/options" +) + +// oaiCountingTransport wraps a base RoundTripper and counts RoundTrip calls. +type oaiCountingTransport struct { + base http.RoundTripper + calls atomic.Int64 +} + +func (c *oaiCountingTransport) RoundTrip(req *http.Request) (*http.Response, error) { + c.calls.Add(1) + return c.base.RoundTrip(req) +} + +func TestNewClient_TransportWrapperInvokedDirectPath(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + writeSSEResponse(w) + })) + defer server.Close() + + var counter oaiCountingTransport + + cfg := &latest.ModelConfig{ + Provider: "openai", + Model: "gpt-4o", + BaseURL: server.URL, + TokenKey: "OPENAI_API_KEY", + } + env := environment.NewMapEnvProvider(map[string]string{ + "OPENAI_API_KEY": "test-key", + }) + + client, err := NewClient(t.Context(), cfg, env, + options.WithHTTPTransportWrapper(func(base http.RoundTripper) http.RoundTripper { + counter.base = base + return &counter + }), + ) + require.NoError(t, err) + + stream, err := client.CreateChatCompletionStream(t.Context(), []chat.Message{ + {Role: chat.MessageRoleUser, Content: "hello"}, + }, nil) + require.NoError(t, err) + defer stream.Close() + + // Drain the stream so RoundTrip is fully exercised. + for { + if _, err := stream.Recv(); err != nil { + break + } + } + + assert.Positive(t, counter.calls.Load(), "transport wrapper RoundTrip should have been called at least once") +} + +func TestNewClient_TransportWrapperInvokedGatewayPath(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + writeSSEResponse(w) + })) + defer server.Close() + + var counter oaiCountingTransport + + cfg := &latest.ModelConfig{ + Provider: "openai", + Model: "gpt-4o", + } + // server.URL is 127.0.0.1 which IsTrustedDockerURL considers trusted, + // so we must supply the Docker Desktop token. + env := environment.NewMapEnvProvider(map[string]string{ + environment.DockerDesktopTokenEnv: "test-dd-token", + }) + + client, err := NewClient(t.Context(), cfg, env, + options.WithGateway(server.URL), + options.WithHTTPTransportWrapper(func(base http.RoundTripper) http.RoundTripper { + counter.base = base + return &counter + }), + ) + require.NoError(t, err) + + stream, err := client.CreateChatCompletionStream(t.Context(), []chat.Message{ + {Role: chat.MessageRoleUser, Content: "hello"}, + }, nil) + require.NoError(t, err) + defer stream.Close() + + // Drain the stream so RoundTrip is fully exercised. + for { + if _, err := stream.Recv(); err != nil { + break + } + } + + assert.Positive(t, counter.calls.Load(), "transport wrapper RoundTrip should have been called at least once in gateway path") +} + +// TestNewClient_WebSocketFallsBackToSSEWhenTransportWrapperSet verifies that +// configuring transport=websocket together with a transport wrapper causes the +// client to fall back to SSE (no wsPool created). gorilla/websocket bypasses +// http.RoundTripper, so websocket would silently drop the wrapper; the fallback +// ensures the wrapper covers every outbound request. +func TestNewClient_WebSocketFallsBackToSSEWhenTransportWrapperSet(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + writeSSEResponse(w) + })) + defer server.Close() + + var counter oaiCountingTransport + + cfg := &latest.ModelConfig{ + Provider: "openai", + Model: "gpt-4o-realtime-preview", + BaseURL: server.URL, + TokenKey: "OPENAI_API_KEY", + ProviderOpts: map[string]any{"transport": "websocket"}, + } + env := environment.NewMapEnvProvider(map[string]string{ + "OPENAI_API_KEY": "test-key", + }) + + client, err := NewClient(t.Context(), cfg, env, + options.WithHTTPTransportWrapper(func(base http.RoundTripper) http.RoundTripper { + counter.base = base + return &counter + }), + ) + require.NoError(t, err) + + // wsPool must not be created when a transport wrapper is registered. + assert.Nil(t, client.wsPool, "wsPool should be nil when a transport wrapper is set") +} diff --git a/pkg/model/provider/options/options.go b/pkg/model/provider/options/options.go index 329e2b276..d9898a85a 100644 --- a/pkg/model/provider/options/options.go +++ b/pkg/model/provider/options/options.go @@ -1,6 +1,8 @@ package options import ( + "net/http" + "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/modelsdev" ) @@ -13,6 +15,7 @@ type ModelOptions struct { maxTokens int64 providers map[string]latest.ProviderConfig modelsDevStore *modelsdev.Store + transportWrapper func(http.RoundTripper) http.RoundTripper } func (c *ModelOptions) Gateway() string { @@ -43,6 +46,12 @@ func (c *ModelOptions) ModelsDevStore() *modelsdev.Store { return c.modelsDevStore } +// TransportWrapper returns the HTTP transport wrapper function registered via +// WithHTTPTransportWrapper, or nil if none was set. +func (c *ModelOptions) TransportWrapper() func(http.RoundTripper) http.RoundTripper { + return c.transportWrapper +} + type Opt func(*ModelOptions) func WithGateway(gateway string) Opt { @@ -87,6 +96,40 @@ func WithModelsDevStore(store *modelsdev.Store) Opt { } } +// WithHTTPTransportWrapper registers a function that wraps the HTTP transport +// used by provider clients (Anthropic, OpenAI, and Gemini with the Gemini API +// backend). The function receives the transport that docker-agent built +// (including OTel instrumentation, SSE decompression fix, and Desktop proxy +// support) and must return a new RoundTripper that delegates to it. The wrapper +// is applied in both direct mode and gateway/proxy mode. +// +// Call-frequency note: in direct mode the wrapper is invoked once at client +// construction time; in gateway mode it is invoked on every LLM request +// (because gateway clients are rebuilt on each call to refresh short-lived +// auth tokens). Wrappers with per-call side effects (metrics, token rotation) +// will therefore be called more frequently in gateway mode. +// +// Limitations: +// - OpenAI clients configured with transport=websocket bypass the HTTP +// transport layer entirely; the wrapper is not applied in that mode. +// - Gemini clients using the Vertex AI backend (project/location config or +// GOOGLE_GENAI_USE_VERTEXAI) rely on the genai SDK's default HTTP client; +// the wrapper is not applied and a warning is logged. +// +// The wrapper function must return a non-nil RoundTripper; returning nil is a +// no-op (a warning is logged and the original transport is kept). +// +// Example — inject a bearer token on every outbound LLM request: +// +// options.WithHTTPTransportWrapper(func(base http.RoundTripper) http.RoundTripper { +// return &bearerTransport{token: myToken, base: base} +// }) +func WithHTTPTransportWrapper(fn func(base http.RoundTripper) http.RoundTripper) Opt { + return func(cfg *ModelOptions) { + cfg.transportWrapper = fn + } +} + // FromModelOptions converts a concrete ModelOptions value into a slice of // Opt configuration functions. Later Opts override earlier ones when applied. func FromModelOptions(m ModelOptions) []Opt { @@ -112,5 +155,8 @@ func FromModelOptions(m ModelOptions) []Opt { if m.modelsDevStore != nil { out = append(out, WithModelsDevStore(m.modelsDevStore)) } + if m.transportWrapper != nil { + out = append(out, WithHTTPTransportWrapper(m.transportWrapper)) + } return out } diff --git a/pkg/model/provider/options/options_test.go b/pkg/model/provider/options/options_test.go new file mode 100644 index 000000000..84f5e1479 --- /dev/null +++ b/pkg/model/provider/options/options_test.go @@ -0,0 +1,80 @@ +package options + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// sentinelTransport is a minimal http.RoundTripper used only for identity checks. +type sentinelTransport struct{ base http.RoundTripper } + +func (s *sentinelTransport) RoundTrip(req *http.Request) (*http.Response, error) { + return s.base.RoundTrip(req) +} + +func TestWithHTTPTransportWrapper_SetAndGet(t *testing.T) { + var called bool + wrapFn := func(base http.RoundTripper) http.RoundTripper { + called = true + return &sentinelTransport{base: base} + } + + var opts ModelOptions + WithHTTPTransportWrapper(wrapFn)(&opts) + + got := opts.TransportWrapper() + require.NotNil(t, got) + + // Verify invoking the returned wrapper marks called=true and returns a non-nil transport. + result := got(http.DefaultTransport) + assert.True(t, called) + assert.NotNil(t, result) +} + +func TestTransportWrapper_NilByDefault(t *testing.T) { + var opts ModelOptions + assert.Nil(t, opts.TransportWrapper()) +} + +func TestFromModelOptions_RoundTripsTransportWrapper(t *testing.T) { + var wrapperInvoked bool + wrapFn := func(base http.RoundTripper) http.RoundTripper { + wrapperInvoked = true + return &sentinelTransport{base: base} + } + + var src ModelOptions + WithHTTPTransportWrapper(wrapFn)(&src) + + opts := FromModelOptions(src) + require.NotEmpty(t, opts) + + var dst ModelOptions + for _, o := range opts { + o(&dst) + } + + got := dst.TransportWrapper() + require.NotNil(t, got) + + result := got(http.DefaultTransport) + assert.True(t, wrapperInvoked) + assert.NotNil(t, result) +} + +func TestFromModelOptions_NilWrapperNotIncluded(t *testing.T) { + // A ModelOptions with no transport wrapper should not add a + // WithHTTPTransportWrapper opt, so TransportWrapper() stays nil. + var src ModelOptions + opts := FromModelOptions(src) + + var dst ModelOptions + for _, o := range opts { + o(&dst) + } + + assert.Nil(t, dst.TransportWrapper()) +}