From 5dcda0f88db3af3491a821c93296ad1d8ca0ce58 Mon Sep 17 00:00:00 2001 From: Simon Ferquel's Clanker Date: Fri, 12 Jun 2026 13:06:28 +0000 Subject: [PATCH 1/7] feat: add options.WithHTTPTransportWrapper to inject HTTP middleware in provider clients Adds a new Opt that allows callers embedding docker-agent programmatically to wrap the HTTP transport used by all provider clients. The wrapper function receives the already-built transport (post-OTel, post-SSE-filter, post-Desktop-proxy) and must return a new RoundTripper that delegates to it. Changes: - options.ModelOptions: new transportWrapper field, WithHTTPTransportWrapper Opt, TransportWrapper() accessor, and FromModelOptions round-trip support - Anthropic provider: wrapper applied in direct and gateway paths - OpenAI provider: wrapper applied in direct and gateway paths - Gemini provider: wrapper applied in direct and gateway paths (GeminiAPI backend); Vertex AI path logs a warning (AWS SDK manages its own client) - Bedrock provider: logs a warning (AWS SDK manages its own client) - All four active call sites guard against a nil return from the wrapper (original transport preserved, slog.Warn emitted) - Tests: unit tests in options package; direct and gateway integration tests in anthropic, openai, and gemini provider packages Fixes #3089 --- pkg/model/provider/anthropic/client.go | 20 ++- .../provider/anthropic/wrap_transport_test.go | 157 ++++++++++++++++++ pkg/model/provider/bedrock/client.go | 4 + pkg/model/provider/gemini/client.go | 21 ++- .../provider/gemini/wrap_transport_test.go | 153 +++++++++++++++++ pkg/model/provider/openai/client.go | 17 +- .../provider/openai/wrap_transport_test.go | 112 +++++++++++++ pkg/model/provider/options/options.go | 48 ++++++ pkg/model/provider/options/options_test.go | 80 +++++++++ 9 files changed, 608 insertions(+), 4 deletions(-) create mode 100644 pkg/model/provider/anthropic/wrap_transport_test.go create mode 100644 pkg/model/provider/gemini/wrap_transport_test.go create mode 100644 pkg/model/provider/openai/wrap_transport_test.go create mode 100644 pkg/model/provider/options/options_test.go diff --git a/pkg/model/provider/anthropic/client.go b/pkg/model/provider/anthropic/client.go index 4e9c5bfe1..bc604363b 100644 --- a/pkg/model/provider/anthropic/client.go +++ b/pkg/model/provider/anthropic/client.go @@ -74,8 +74,16 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro slog.ErrorContext(ctx, "Anthropic client creation failed", "error", err) return nil, err } + httpClient := httpclient.NewHTTPClient(ctx) + if w := globalOptions.TransportWrapper(); w != nil { + if wrapped := w(httpClient.Transport); wrapped != nil { + httpClient.Transport = wrapped + } else { + slog.WarnContext(ctx, "HTTP transport wrapper returned nil; using original transport") + } + } requestOptions := append([]option.RequestOption{ - option.WithHTTPClient(httpclient.NewHTTPClient(ctx)), + option.WithHTTPClient(httpClient), }, authOpts...) if cfg.BaseURL != "" { requestOptions = append(requestOptions, option.WithBaseURL(cfg.BaseURL)) @@ -127,9 +135,17 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro httpOptions = append(httpOptions, httpclient.WithHeader("X-Cagent-GeneratingTitle", "1")) } + gatewayHTTPClient := httpclient.NewHTTPClient(ctx, httpOptions...) + if w := globalOptions.TransportWrapper(); w != nil { + if wrapped := w(gatewayHTTPClient.Transport); wrapped != nil { + gatewayHTTPClient.Transport = wrapped + } else { + slog.WarnContext(ctx, "HTTP transport wrapper returned nil; using original transport") + } + } clientOptions := []option.RequestOption{ option.WithBaseURL(baseURL), - option.WithHTTPClient(httpclient.NewHTTPClient(ctx, httpOptions...)), + option.WithHTTPClient(gatewayHTTPClient), } if authToken != "" { clientOptions = append(clientOptions, option.WithAuthToken(authToken), option.WithAPIKey(authToken)) diff --git a/pkg/model/provider/anthropic/wrap_transport_test.go b/pkg/model/provider/anthropic/wrap_transport_test.go new file mode 100644 index 000000000..59a0ea2be --- /dev/null +++ b/pkg/model/provider/anthropic/wrap_transport_test.go @@ -0,0 +1,157 @@ +package anthropic + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/config/latest" + "github.com/docker/docker-agent/pkg/environment" + "github.com/docker/docker-agent/pkg/model/provider/options" +) + +// countingTransport wraps a base RoundTripper and counts how many times +// RoundTrip is called. +type countingTransport struct { + base http.RoundTripper + calls atomic.Int64 +} + +func (c *countingTransport) RoundTrip(req *http.Request) (*http.Response, error) { + c.calls.Add(1) + return c.base.RoundTrip(req) +} + +// writeMinimalAnthropicSSE writes a bare-minimum valid Anthropic SSE stream +// so that the streaming client does not error before we can observe transport invocation. +func writeMinimalAnthropicSSE(w http.ResponseWriter) { + w.Header().Set("Content-Type", "text/event-stream") + flusher, _ := w.(http.Flusher) + + writeEvent := func(eventType string, payload any) { + data, _ := json.Marshal(payload) + _, _ = w.Write([]byte("event: " + eventType + "\n")) + _, _ = w.Write([]byte("data: " + string(data) + "\n\n")) + if flusher != nil { + flusher.Flush() + } + } + + writeEvent("message_start", map[string]any{ + "type": "message_start", + "message": map[string]any{"id": "msg_test", "model": "claude-test", "role": "assistant", "type": "message", "content": []any{}, "stop_reason": nil, "usage": map[string]any{"input_tokens": 5, "output_tokens": 0}}, + }) + writeEvent("content_block_start", map[string]any{ + "type": "content_block_start", + "index": 0, + "content_block": map[string]any{ + "type": "text", + "text": "", + }, + }) + writeEvent("content_block_delta", map[string]any{ + "type": "content_block_delta", + "index": 0, + "delta": map[string]any{"type": "text_delta", "text": "hi"}, + }) + writeEvent("content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": 0, + }) + writeEvent("message_delta", map[string]any{ + "type": "message_delta", + "delta": map[string]any{"stop_reason": "end_turn", "stop_sequence": nil}, + "usage": map[string]any{"output_tokens": 1}, + }) + writeEvent("message_stop", map[string]any{"type": "message_stop"}) +} + +func TestNewClient_TransportWrapperInvokedDirectPath(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + writeMinimalAnthropicSSE(w) + })) + defer server.Close() + + var counter countingTransport + + cfg := &latest.ModelConfig{ + Provider: "anthropic", + Model: "claude-3-5-haiku-latest", + BaseURL: server.URL, + } + env := environment.NewMapEnvProvider(map[string]string{ + "ANTHROPIC_API_KEY": "test-key", + }) + + client, err := NewClient(t.Context(), cfg, env, + options.WithHTTPTransportWrapper(func(base http.RoundTripper) http.RoundTripper { + counter.base = base + return &counter + }), + ) + require.NoError(t, err) + + stream, err := client.CreateChatCompletionStream(t.Context(), []chat.Message{ + {Role: chat.MessageRoleUser, Content: "hello"}, + }, nil) + require.NoError(t, err) + defer stream.Close() + + // Drain the stream so RoundTrip has been fully exercised. + for { + if _, err := stream.Recv(); err != nil { + break + } + } + + assert.Greater(t, counter.calls.Load(), int64(0), "transport wrapper RoundTrip should have been called at least once") +} + +func TestNewClient_TransportWrapperInvokedGatewayPath(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + writeMinimalAnthropicSSE(w) + })) + defer server.Close() + + var counter countingTransport + + cfg := &latest.ModelConfig{ + Provider: "anthropic", + Model: "claude-3-5-haiku-latest", + } + // server.URL is 127.0.0.1 which IsTrustedDockerURL considers trusted, + // so we must supply the Docker Desktop token. + env := environment.NewMapEnvProvider(map[string]string{ + environment.DockerDesktopTokenEnv: "test-dd-token", + }) + + client, err := NewClient(t.Context(), cfg, env, + options.WithGateway(server.URL), + options.WithHTTPTransportWrapper(func(base http.RoundTripper) http.RoundTripper { + counter.base = base + return &counter + }), + ) + require.NoError(t, err) + + stream, err := client.CreateChatCompletionStream(t.Context(), []chat.Message{ + {Role: chat.MessageRoleUser, Content: "hello"}, + }, nil) + require.NoError(t, err) + defer stream.Close() + + // Drain the stream so RoundTrip has been fully exercised. + for { + if _, err := stream.Recv(); err != nil { + break + } + } + + assert.Greater(t, counter.calls.Load(), int64(0), "transport wrapper RoundTrip should have been called at least once in gateway path") +} diff --git a/pkg/model/provider/bedrock/client.go b/pkg/model/provider/bedrock/client.go index 282ef56f2..97b6dba04 100644 --- a/pkg/model/provider/bedrock/client.go +++ b/pkg/model/provider/bedrock/client.go @@ -60,6 +60,10 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro opt(&globalOptions) } + if globalOptions.TransportWrapper() != nil { + slog.WarnContext(ctx, "HTTP transport wrapper is set but not applied: Bedrock provider uses the AWS SDK HTTP client") + } + // Check for bearer token - use token_key if specified, otherwise try AWS_BEARER_TOKEN_BEDROCK. // Bearer token is optional: if not provided, falls back to standard AWS credential chain (SigV4). // diff --git a/pkg/model/provider/gemini/client.go b/pkg/model/provider/gemini/client.go index 593308359..7400753d9 100644 --- a/pkg/model/provider/gemini/client.go +++ b/pkg/model/provider/gemini/client.go @@ -103,6 +103,16 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro httpClient = httpclient.NewHTTPClient(ctx) } + if w := globalOptions.TransportWrapper(); w != nil { + if httpClient == nil { + slog.WarnContext(ctx, "HTTP transport wrapper is set but not applied: Gemini Vertex AI backend uses an SDK-managed HTTP client") + } else if wrapped := w(httpClient.Transport); wrapped != nil { + httpClient.Transport = wrapped + } else { + slog.WarnContext(ctx, "HTTP transport wrapper returned nil; using original transport") + } + } + client, err := genai.NewClient(ctx, &genai.ClientConfig{ APIKey: apiKey, Project: project, @@ -168,10 +178,19 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro } } + gatewayHTTPClient := httpclient.NewHTTPClient(ctx, httpOptions...) + if w := globalOptions.TransportWrapper(); w != nil { + if wrapped := w(gatewayHTTPClient.Transport); wrapped != nil { + gatewayHTTPClient.Transport = wrapped + } else { + slog.WarnContext(ctx, "HTTP transport wrapper returned nil; using original transport") + } + } + return genai.NewClient(ctx, &genai.ClientConfig{ APIKey: authToken, Backend: genai.BackendGeminiAPI, - HTTPClient: httpclient.NewHTTPClient(ctx, httpOptions...), + HTTPClient: gatewayHTTPClient, HTTPOptions: httpOpts, }) } diff --git a/pkg/model/provider/gemini/wrap_transport_test.go b/pkg/model/provider/gemini/wrap_transport_test.go new file mode 100644 index 000000000..36ce2a4d0 --- /dev/null +++ b/pkg/model/provider/gemini/wrap_transport_test.go @@ -0,0 +1,153 @@ +package gemini + +import ( + "fmt" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/config/latest" + "github.com/docker/docker-agent/pkg/environment" + "github.com/docker/docker-agent/pkg/model/provider/options" +) + +// geminiCountingTransport wraps a base RoundTripper and counts RoundTrip calls. +type geminiCountingTransport struct { + base http.RoundTripper + calls atomic.Int64 +} + +func (c *geminiCountingTransport) RoundTrip(req *http.Request) (*http.Response, error) { + c.calls.Add(1) + return c.base.RoundTrip(req) +} + +// writeGeminiSSEResponse writes a minimal valid Gemini streaming response. +func writeGeminiSSEResponse(w http.ResponseWriter) { + w.Header().Set("Content-Type", "text/event-stream") + flusher, _ := w.(http.Flusher) + + payload := `{"candidates":[{"content":{"parts":[{"text":"hi"}],"role":"model"},"finishReason":"STOP","index":0}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2}}` + _, _ = fmt.Fprintf(w, "data: %s\n\n", payload) + if flusher != nil { + flusher.Flush() + } +} + +func TestNewClient_TransportWrapperInvokedDirectPath(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + writeGeminiSSEResponse(w) + })) + defer server.Close() + + var counter geminiCountingTransport + + cfg := &latest.ModelConfig{ + Provider: "google", + Model: "gemini-2.0-flash", + BaseURL: server.URL, + } + env := environment.NewMapEnvProvider(map[string]string{ + "GOOGLE_API_KEY": "test-key", + }) + + client, err := NewClient(t.Context(), cfg, env, + options.WithHTTPTransportWrapper(func(base http.RoundTripper) http.RoundTripper { + counter.base = base + return &counter + }), + ) + require.NoError(t, err) + + stream, err := client.CreateChatCompletionStream(t.Context(), []chat.Message{ + {Role: chat.MessageRoleUser, Content: "hello"}, + }, nil) + require.NoError(t, err) + defer stream.Close() + + // Drain the stream so RoundTrip has been fully exercised. + for { + if _, err := stream.Recv(); err != nil { + break + } + } + + assert.Greater(t, counter.calls.Load(), int64(0), "transport wrapper RoundTrip should have been called at least once") +} + +func TestNewClient_TransportWrapperInvokedGatewayPath(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + writeGeminiSSEResponse(w) + })) + defer server.Close() + + var counter geminiCountingTransport + + cfg := &latest.ModelConfig{ + Provider: "google", + Model: "gemini-2.0-flash", + } + // server.URL is 127.0.0.1 which IsTrustedDockerURL considers trusted, + // so we must supply the Docker Desktop token. + env := environment.NewMapEnvProvider(map[string]string{ + environment.DockerDesktopTokenEnv: "test-dd-token", + }) + + client, err := NewClient(t.Context(), cfg, env, + options.WithGateway(server.URL), + options.WithHTTPTransportWrapper(func(base http.RoundTripper) http.RoundTripper { + counter.base = base + return &counter + }), + ) + require.NoError(t, err) + + stream, err := client.CreateChatCompletionStream(t.Context(), []chat.Message{ + {Role: chat.MessageRoleUser, Content: "hello"}, + }, nil) + require.NoError(t, err) + defer stream.Close() + + // Drain the stream so RoundTrip has been fully exercised. + for { + if _, err := stream.Recv(); err != nil { + break + } + } + + assert.Greater(t, counter.calls.Load(), int64(0), "transport wrapper RoundTrip should have been called at least once in gateway path") +} + +func TestNewClient_TransportWrapperVertexAIWarns(t *testing.T) { + // When a Vertex AI backend is used (project+location configured), the genai + // SDK manages its own HTTP client. The transport wrapper cannot be applied. + // Verify that NewClient succeeds (no error) and the wrapper function itself + // is not called — the caller receives a slog warning instead. + var wrapperInvoked bool + + cfg := &latest.ModelConfig{ + Provider: "google", + Model: "gemini-2.0-flash", + ProviderOpts: map[string]any{ + "project": "test-project", + "location": "us-central1", + }, + } + env := environment.NewMapEnvProvider(map[string]string{}) + + _, err := NewClient(t.Context(), cfg, env, + options.WithHTTPTransportWrapper(func(base http.RoundTripper) http.RoundTripper { + wrapperInvoked = true + return &geminiCountingTransport{base: base} + }), + ) + // NewClient may fail because Vertex AI requires real ADC credentials; + // we only care that the wrapper function itself was NOT invoked. + _ = err + assert.False(t, wrapperInvoked, "wrapper function should not be invoked for Vertex AI backend") +} diff --git a/pkg/model/provider/openai/client.go b/pkg/model/provider/openai/client.go index 3c5a54a98..322a5aefe 100644 --- a/pkg/model/provider/openai/client.go +++ b/pkg/model/provider/openai/client.go @@ -103,6 +103,13 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro clientOptions = append(clientOptions, option.WithMiddleware(oaistream.ErrorBodyMiddleware())) httpClient := httpclient.NewHTTPClient(ctx) + if w := globalOptions.TransportWrapper(); w != nil { + if wrapped := w(httpClient.Transport); wrapped != nil { + httpClient.Transport = wrapped + } else { + slog.WarnContext(ctx, "HTTP transport wrapper returned nil; using original transport") + } + } clientOptions = append(clientOptions, option.WithHTTPClient(httpClient)) client := openai.NewClient(clientOptions...) @@ -149,9 +156,17 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro httpOptions = append(httpOptions, httpclient.WithHeader("X-Cagent-GeneratingTitle", "1")) } + gatewayHTTPClient := httpclient.NewHTTPClient(ctx, httpOptions...) + if w := globalOptions.TransportWrapper(); w != nil { + if wrapped := w(gatewayHTTPClient.Transport); wrapped != nil { + gatewayHTTPClient.Transport = wrapped + } else { + slog.WarnContext(ctx, "HTTP transport wrapper returned nil; using original transport") + } + } clientOptions := []option.RequestOption{ option.WithBaseURL(baseURL), - option.WithHTTPClient(httpclient.NewHTTPClient(ctx, httpOptions...)), + option.WithHTTPClient(gatewayHTTPClient), option.WithMiddleware(oaistream.ErrorBodyMiddleware()), } if authToken != "" { diff --git a/pkg/model/provider/openai/wrap_transport_test.go b/pkg/model/provider/openai/wrap_transport_test.go new file mode 100644 index 000000000..417237b17 --- /dev/null +++ b/pkg/model/provider/openai/wrap_transport_test.go @@ -0,0 +1,112 @@ +package openai + +import ( + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/config/latest" + "github.com/docker/docker-agent/pkg/environment" + "github.com/docker/docker-agent/pkg/model/provider/options" +) + +// oaiCountingTransport wraps a base RoundTripper and counts RoundTrip calls. +type oaiCountingTransport struct { + base http.RoundTripper + calls atomic.Int64 +} + +func (c *oaiCountingTransport) RoundTrip(req *http.Request) (*http.Response, error) { + c.calls.Add(1) + return c.base.RoundTrip(req) +} + +func TestNewClient_TransportWrapperInvokedDirectPath(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + writeSSEResponse(w) + })) + defer server.Close() + + var counter oaiCountingTransport + + cfg := &latest.ModelConfig{ + Provider: "openai", + Model: "gpt-4o", + BaseURL: server.URL, + TokenKey: "OPENAI_API_KEY", + } + env := environment.NewMapEnvProvider(map[string]string{ + "OPENAI_API_KEY": "test-key", + }) + + client, err := NewClient(t.Context(), cfg, env, + options.WithHTTPTransportWrapper(func(base http.RoundTripper) http.RoundTripper { + counter.base = base + return &counter + }), + ) + require.NoError(t, err) + + stream, err := client.CreateChatCompletionStream(t.Context(), []chat.Message{ + {Role: chat.MessageRoleUser, Content: "hello"}, + }, nil) + require.NoError(t, err) + defer stream.Close() + + // Drain the stream so RoundTrip is fully exercised. + for { + if _, err := stream.Recv(); err != nil { + break + } + } + + assert.Greater(t, counter.calls.Load(), int64(0), "transport wrapper RoundTrip should have been called at least once") +} + +func TestNewClient_TransportWrapperInvokedGatewayPath(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + writeSSEResponse(w) + })) + defer server.Close() + + var counter oaiCountingTransport + + cfg := &latest.ModelConfig{ + Provider: "openai", + Model: "gpt-4o", + } + // server.URL is 127.0.0.1 which IsTrustedDockerURL considers trusted, + // so we must supply the Docker Desktop token. + env := environment.NewMapEnvProvider(map[string]string{ + environment.DockerDesktopTokenEnv: "test-dd-token", + }) + + client, err := NewClient(t.Context(), cfg, env, + options.WithGateway(server.URL), + options.WithHTTPTransportWrapper(func(base http.RoundTripper) http.RoundTripper { + counter.base = base + return &counter + }), + ) + require.NoError(t, err) + + stream, err := client.CreateChatCompletionStream(t.Context(), []chat.Message{ + {Role: chat.MessageRoleUser, Content: "hello"}, + }, nil) + require.NoError(t, err) + defer stream.Close() + + // Drain the stream so RoundTrip is fully exercised. + for { + if _, err := stream.Recv(); err != nil { + break + } + } + + assert.Greater(t, counter.calls.Load(), int64(0), "transport wrapper RoundTrip should have been called at least once in gateway path") +} diff --git a/pkg/model/provider/options/options.go b/pkg/model/provider/options/options.go index 329e2b276..066df89df 100644 --- a/pkg/model/provider/options/options.go +++ b/pkg/model/provider/options/options.go @@ -1,6 +1,8 @@ package options import ( + "net/http" + "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/modelsdev" ) @@ -13,6 +15,7 @@ type ModelOptions struct { maxTokens int64 providers map[string]latest.ProviderConfig modelsDevStore *modelsdev.Store + transportWrapper func(http.RoundTripper) http.RoundTripper } func (c *ModelOptions) Gateway() string { @@ -43,6 +46,12 @@ func (c *ModelOptions) ModelsDevStore() *modelsdev.Store { return c.modelsDevStore } +// TransportWrapper returns the HTTP transport wrapper function registered via +// WithHTTPTransportWrapper, or nil if none was set. +func (c *ModelOptions) TransportWrapper() func(http.RoundTripper) http.RoundTripper { + return c.transportWrapper +} + type Opt func(*ModelOptions) func WithGateway(gateway string) Opt { @@ -87,6 +96,42 @@ func WithModelsDevStore(store *modelsdev.Store) Opt { } } +// WithHTTPTransportWrapper registers a function that wraps the HTTP transport +// used by provider clients (Anthropic, OpenAI, and Gemini with the Gemini API +// backend). The function receives the transport that docker-agent built +// (including OTel instrumentation, SSE decompression fix, and Desktop proxy +// support) and must return a new RoundTripper that delegates to it. The wrapper +// is applied in both direct mode and gateway/proxy mode. +// +// Call-frequency note: in direct mode the wrapper is invoked once at client +// construction time; in gateway mode it is invoked on every LLM request +// (because gateway clients are rebuilt on each call to refresh short-lived +// auth tokens). Wrappers with per-call side effects (metrics, token rotation) +// will therefore be called more frequently in gateway mode. +// +// Limitations: +// - OpenAI clients configured with transport=websocket bypass the HTTP +// transport layer entirely; the wrapper is not applied in that mode. +// - Gemini clients using the Vertex AI backend (project/location config or +// GOOGLE_GENAI_USE_VERTEXAI) rely on the genai SDK's default HTTP client; +// the wrapper is not applied and a warning is logged. +// - The Bedrock provider uses the AWS SDK's own HTTP client; the wrapper is +// not applied and a warning is logged. +// +// The wrapper function must return a non-nil RoundTripper; returning nil is a +// no-op (a warning is logged and the original transport is kept). +// +// Example — inject a bearer token on every outbound LLM request: +// +// options.WithHTTPTransportWrapper(func(base http.RoundTripper) http.RoundTripper { +// return &bearerTransport{token: myToken, base: base} +// }) +func WithHTTPTransportWrapper(fn func(base http.RoundTripper) http.RoundTripper) Opt { + return func(cfg *ModelOptions) { + cfg.transportWrapper = fn + } +} + // FromModelOptions converts a concrete ModelOptions value into a slice of // Opt configuration functions. Later Opts override earlier ones when applied. func FromModelOptions(m ModelOptions) []Opt { @@ -112,5 +157,8 @@ func FromModelOptions(m ModelOptions) []Opt { if m.modelsDevStore != nil { out = append(out, WithModelsDevStore(m.modelsDevStore)) } + if m.transportWrapper != nil { + out = append(out, WithHTTPTransportWrapper(m.transportWrapper)) + } return out } diff --git a/pkg/model/provider/options/options_test.go b/pkg/model/provider/options/options_test.go new file mode 100644 index 000000000..84f5e1479 --- /dev/null +++ b/pkg/model/provider/options/options_test.go @@ -0,0 +1,80 @@ +package options + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// sentinelTransport is a minimal http.RoundTripper used only for identity checks. +type sentinelTransport struct{ base http.RoundTripper } + +func (s *sentinelTransport) RoundTrip(req *http.Request) (*http.Response, error) { + return s.base.RoundTrip(req) +} + +func TestWithHTTPTransportWrapper_SetAndGet(t *testing.T) { + var called bool + wrapFn := func(base http.RoundTripper) http.RoundTripper { + called = true + return &sentinelTransport{base: base} + } + + var opts ModelOptions + WithHTTPTransportWrapper(wrapFn)(&opts) + + got := opts.TransportWrapper() + require.NotNil(t, got) + + // Verify invoking the returned wrapper marks called=true and returns a non-nil transport. + result := got(http.DefaultTransport) + assert.True(t, called) + assert.NotNil(t, result) +} + +func TestTransportWrapper_NilByDefault(t *testing.T) { + var opts ModelOptions + assert.Nil(t, opts.TransportWrapper()) +} + +func TestFromModelOptions_RoundTripsTransportWrapper(t *testing.T) { + var wrapperInvoked bool + wrapFn := func(base http.RoundTripper) http.RoundTripper { + wrapperInvoked = true + return &sentinelTransport{base: base} + } + + var src ModelOptions + WithHTTPTransportWrapper(wrapFn)(&src) + + opts := FromModelOptions(src) + require.NotEmpty(t, opts) + + var dst ModelOptions + for _, o := range opts { + o(&dst) + } + + got := dst.TransportWrapper() + require.NotNil(t, got) + + result := got(http.DefaultTransport) + assert.True(t, wrapperInvoked) + assert.NotNil(t, result) +} + +func TestFromModelOptions_NilWrapperNotIncluded(t *testing.T) { + // A ModelOptions with no transport wrapper should not add a + // WithHTTPTransportWrapper opt, so TransportWrapper() stays nil. + var src ModelOptions + opts := FromModelOptions(src) + + var dst ModelOptions + for _, o := range opts { + o(&dst) + } + + assert.Nil(t, dst.TransportWrapper()) +} From 59057b58197fe5d49101abf33dd1b132debe3573 Mon Sep 17 00:00:00 2001 From: Simon Ferquel's Clanker Date: Fri, 12 Jun 2026 13:13:56 +0000 Subject: [PATCH 2/7] fix: address lint failures and WebSocket wrapper warning - Replace assert.Greater(t, x, int64(0)) with assert.Positive(t, x) in all three wrap_transport_test files to satisfy testifylint - Emit slog.WarnContext when transport wrapper is set but OpenAI WebSocket transport is in use (consistent with Bedrock and Gemini Vertex AI) --- pkg/model/provider/anthropic/wrap_transport_test.go | 4 ++-- pkg/model/provider/gemini/wrap_transport_test.go | 4 ++-- pkg/model/provider/openai/client.go | 3 +++ pkg/model/provider/openai/wrap_transport_test.go | 4 ++-- 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/pkg/model/provider/anthropic/wrap_transport_test.go b/pkg/model/provider/anthropic/wrap_transport_test.go index 59a0ea2be..5eac2140e 100644 --- a/pkg/model/provider/anthropic/wrap_transport_test.go +++ b/pkg/model/provider/anthropic/wrap_transport_test.go @@ -110,7 +110,7 @@ func TestNewClient_TransportWrapperInvokedDirectPath(t *testing.T) { } } - assert.Greater(t, counter.calls.Load(), int64(0), "transport wrapper RoundTrip should have been called at least once") + assert.Positive(t, counter.calls.Load(), "transport wrapper RoundTrip should have been called at least once") } func TestNewClient_TransportWrapperInvokedGatewayPath(t *testing.T) { @@ -153,5 +153,5 @@ func TestNewClient_TransportWrapperInvokedGatewayPath(t *testing.T) { } } - assert.Greater(t, counter.calls.Load(), int64(0), "transport wrapper RoundTrip should have been called at least once in gateway path") + assert.Positive(t, counter.calls.Load(), "transport wrapper RoundTrip should have been called at least once in gateway path") } diff --git a/pkg/model/provider/gemini/wrap_transport_test.go b/pkg/model/provider/gemini/wrap_transport_test.go index 36ce2a4d0..fc63fc483 100644 --- a/pkg/model/provider/gemini/wrap_transport_test.go +++ b/pkg/model/provider/gemini/wrap_transport_test.go @@ -77,7 +77,7 @@ func TestNewClient_TransportWrapperInvokedDirectPath(t *testing.T) { } } - assert.Greater(t, counter.calls.Load(), int64(0), "transport wrapper RoundTrip should have been called at least once") + assert.Positive(t, counter.calls.Load(), "transport wrapper RoundTrip should have been called at least once") } func TestNewClient_TransportWrapperInvokedGatewayPath(t *testing.T) { @@ -120,7 +120,7 @@ func TestNewClient_TransportWrapperInvokedGatewayPath(t *testing.T) { } } - assert.Greater(t, counter.calls.Load(), int64(0), "transport wrapper RoundTrip should have been called at least once in gateway path") + assert.Positive(t, counter.calls.Load(), "transport wrapper RoundTrip should have been called at least once in gateway path") } func TestNewClient_TransportWrapperVertexAIWarns(t *testing.T) { diff --git a/pkg/model/provider/openai/client.go b/pkg/model/provider/openai/client.go index 322a5aefe..bf869f04b 100644 --- a/pkg/model/provider/openai/client.go +++ b/pkg/model/provider/openai/client.go @@ -195,6 +195,9 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro if getTransport(cfg) == "websocket" && globalOptions.Gateway() == "" { baseURL := cmp.Or(cfg.BaseURL, "https://api.openai.com/v1") client.wsPool = newWSPool(httpToWSURL(baseURL), client.buildWSHeaderFn()) + if globalOptions.TransportWrapper() != nil { + slog.WarnContext(ctx, "HTTP transport wrapper is set but not applied: WebSocket transport uses an SDK-managed connection") + } } return client, nil diff --git a/pkg/model/provider/openai/wrap_transport_test.go b/pkg/model/provider/openai/wrap_transport_test.go index 417237b17..e8b8bfffb 100644 --- a/pkg/model/provider/openai/wrap_transport_test.go +++ b/pkg/model/provider/openai/wrap_transport_test.go @@ -65,7 +65,7 @@ func TestNewClient_TransportWrapperInvokedDirectPath(t *testing.T) { } } - assert.Greater(t, counter.calls.Load(), int64(0), "transport wrapper RoundTrip should have been called at least once") + assert.Positive(t, counter.calls.Load(), "transport wrapper RoundTrip should have been called at least once") } func TestNewClient_TransportWrapperInvokedGatewayPath(t *testing.T) { @@ -108,5 +108,5 @@ func TestNewClient_TransportWrapperInvokedGatewayPath(t *testing.T) { } } - assert.Greater(t, counter.calls.Load(), int64(0), "transport wrapper RoundTrip should have been called at least once in gateway path") + assert.Positive(t, counter.calls.Load(), "transport wrapper RoundTrip should have been called at least once in gateway path") } From d7c54111239b447ab02993f5233a1220f0857e24 Mon Sep 17 00:00:00 2001 From: Simon Ferquel's Clanker Date: Fri, 12 Jun 2026 13:25:33 +0000 Subject: [PATCH 3/7] feat: wire TransportWrapper into Bedrock provider The AWS SDK v2 supports a custom HTTP client via bedrockruntime.Options.HTTPClient (which accepts any aws.HTTPClient, and *http.Client satisfies that interface). Replace the warn-and-skip approach with the same pattern used for Anthropic, OpenAI, and Gemini: - Build the docker-agent HTTP client (OTel, SSE decompression, user-agent) - If a bearer token is configured, chain bearerTokenTransport on top of that base (previously it wrapped http.DefaultTransport, bypassing OTel/SSE entirely) - Apply the caller's transport wrapper over the full chain - Inject the result into bedrockruntime.Options.HTTPClient Also remove the Bedrock limitation bullet from the WithHTTPTransportWrapper docstring since it is now fully supported. --- pkg/model/provider/bedrock/client.go | 40 +++++++++++++++++++++------ pkg/model/provider/options/options.go | 2 -- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/pkg/model/provider/bedrock/client.go b/pkg/model/provider/bedrock/client.go index 97b6dba04..c9d60db27 100644 --- a/pkg/model/provider/bedrock/client.go +++ b/pkg/model/provider/bedrock/client.go @@ -18,6 +18,7 @@ import ( "github.com/docker/docker-agent/pkg/chat" "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/environment" + "github.com/docker/docker-agent/pkg/httpclient" "github.com/docker/docker-agent/pkg/model/provider/base" "github.com/docker/docker-agent/pkg/model/provider/options" "github.com/docker/docker-agent/pkg/model/provider/providerutil" @@ -64,7 +65,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro slog.WarnContext(ctx, "HTTP transport wrapper is set but not applied: Bedrock provider uses the AWS SDK HTTP client") } - // Check for bearer token - use token_key if specified, otherwise try AWS_BEARER_TOKEN_BEDROCK. + // Check for bearer token // Bearer token is optional: if not provided, falls back to standard AWS credential chain (SigV4). // // NOTE: Manual token handling is required because aws-sdk-go-v2's default credential chain @@ -81,6 +82,28 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro bearerToken, _ = env.Get(ctx, "AWS_BEARER_TOKEN_BEDROCK") } + // Build the docker-agent HTTP client (OTel instrumentation, SSE decompression, + // Desktop proxy support) so transport-level concerns apply to Bedrock too. + httpClient := httpclient.NewHTTPClient(ctx) + + // If a bearer token is set, chain it on top of the base transport so auth + // headers are injected without replacing the rest of the transport stack. + if bearerToken != "" { + httpClient.Transport = &bearerTokenTransport{ + token: bearerToken, + base: httpClient.Transport, + } + } + + // Apply the transport wrapper, if registered, over the full chain. + if w := globalOptions.TransportWrapper(); w != nil { + if wrapped := w(httpClient.Transport); wrapped != nil { + httpClient.Transport = wrapped + } else { + slog.WarnContext(ctx, "HTTP transport wrapper returned nil; using original transport") + } + } + // Build AWS config using default credential chain awsCfg, err := buildAWSConfig(ctx, cfg, env) if err != nil { @@ -98,19 +121,18 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro }) } - // If bearer token is set, use it instead of SigV4 + // Inject our HTTP client (which carries OTel, SSE, bearer token if set, and + // any caller-registered transport wrapper) into the Bedrock runtime options. if bearerToken != "" { slog.DebugContext(ctx, "Bedrock using bearer token authentication") clientOpts = append(clientOpts, func(o *bedrockruntime.Options) { // Use anonymous credentials to skip SigV4 signing o.Credentials = aws.AnonymousCredentials{} - // Add bearer token via custom HTTP client - o.HTTPClient = &http.Client{ - Transport: &bearerTokenTransport{ - token: bearerToken, - base: http.DefaultTransport, - }, - } + o.HTTPClient = httpClient + }) + } else { + clientOpts = append(clientOpts, func(o *bedrockruntime.Options) { + o.HTTPClient = httpClient }) } diff --git a/pkg/model/provider/options/options.go b/pkg/model/provider/options/options.go index 066df89df..d9898a85a 100644 --- a/pkg/model/provider/options/options.go +++ b/pkg/model/provider/options/options.go @@ -115,8 +115,6 @@ func WithModelsDevStore(store *modelsdev.Store) Opt { // - Gemini clients using the Vertex AI backend (project/location config or // GOOGLE_GENAI_USE_VERTEXAI) rely on the genai SDK's default HTTP client; // the wrapper is not applied and a warning is logged. -// - The Bedrock provider uses the AWS SDK's own HTTP client; the wrapper is -// not applied and a warning is logged. // // The wrapper function must return a non-nil RoundTripper; returning nil is a // no-op (a warning is logged and the original transport is kept). From b387a68d9b16ade3571f16322fbefb27358ec1ab Mon Sep 17 00:00:00 2001 From: Simon Ferquel's Clanker Date: Fri, 12 Jun 2026 13:51:37 +0000 Subject: [PATCH 4/7] fix: remove stale TransportWrapper warning from Bedrock provider The warning 'HTTP transport wrapper is set but not applied' was a leftover from before the wrapper was wired in (d7c5411). The wrapper is now applied at line 99 via bedrockruntime.Options.HTTPClient, so the upfront warn block was wrong and is removed. --- pkg/model/provider/bedrock/client.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pkg/model/provider/bedrock/client.go b/pkg/model/provider/bedrock/client.go index c9d60db27..fb6bcce1e 100644 --- a/pkg/model/provider/bedrock/client.go +++ b/pkg/model/provider/bedrock/client.go @@ -61,10 +61,6 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro opt(&globalOptions) } - if globalOptions.TransportWrapper() != nil { - slog.WarnContext(ctx, "HTTP transport wrapper is set but not applied: Bedrock provider uses the AWS SDK HTTP client") - } - // Check for bearer token // Bearer token is optional: if not provided, falls back to standard AWS credential chain (SigV4). // From 5bf51281746be7b7952822806a67680aa1d4cbdc Mon Sep 17 00:00:00 2001 From: Simon Ferquel's Clanker Date: Fri, 12 Jun 2026 14:06:00 +0000 Subject: [PATCH 5/7] feat: fall back to SSE when transport wrapper is set and transport=websocket gorilla/websocket dials raw TCP/TLS and never calls http.RoundTripper, so an HTTP transport wrapper cannot intercept WebSocket connections. Rather than warning and silently dropping the wrapper, add TransportWrapper() == nil to the WebSocket guards in both NewClient and CreateResponseStream so that callers who register a wrapper automatically get SSE (where the wrapper is fully applied) instead of WebSocket. Changes: - NewClient: wsPool is not created when TransportWrapper() != nil - CreateResponseStream: split the fallback log into two debug messages -- one for the gateway case, one for the transport-wrapper case - Remove the now-unnecessary slog.Warn about WebSocket bypassing the wrapper - Add TestNewClient_WebSocketFallsBackToSSEWhenTransportWrapperSet --- pkg/model/provider/openai/client.go | 18 ++++++---- .../provider/openai/wrap_transport_test.go | 36 +++++++++++++++++++ 2 files changed, 48 insertions(+), 6 deletions(-) diff --git a/pkg/model/provider/openai/client.go b/pkg/model/provider/openai/client.go index bf869f04b..21ea8e7e3 100644 --- a/pkg/model/provider/openai/client.go +++ b/pkg/model/provider/openai/client.go @@ -192,12 +192,12 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro // Pre-create the WebSocket pool when the transport is configured. // The pool is cheap (no connections opened until the first Stream call) // and eager init avoids a data race on the lazy path. - if getTransport(cfg) == "websocket" && globalOptions.Gateway() == "" { + // WebSocket is also skipped when an HTTP transport wrapper is registered: + // gorilla/websocket dials raw TCP and never calls http.RoundTripper, so the + // wrapper cannot be applied. Fall back to SSE so the wrapper covers all calls. + if getTransport(cfg) == "websocket" && globalOptions.Gateway() == "" && globalOptions.TransportWrapper() == nil { baseURL := cmp.Or(cfg.BaseURL, "https://api.openai.com/v1") client.wsPool = newWSPool(httpToWSURL(baseURL), client.buildWSHeaderFn()) - if globalOptions.TransportWrapper() != nil { - slog.WarnContext(ctx, "HTTP transport wrapper is set but not applied: WebSocket transport uses an SDK-managed connection") - } } return client, nil @@ -530,10 +530,13 @@ func (c *Client) CreateResponseStream( // Choose transport: WebSocket or SSE (default). // WebSocket is disabled when using a Gateway since most gateways don't support it. + // WebSocket is also disabled when an HTTP transport wrapper is registered: gorilla/websocket + // dials raw TCP and never calls http.RoundTripper, so the wrapper cannot intercept those + // connections. Fall back to SSE so the wrapper applies to all requests. transport := getTransport(&c.ModelConfig) trackUsage := c.ModelConfig.TrackUsage == nil || *c.ModelConfig.TrackUsage - if transport == "websocket" && c.ModelOptions.Gateway() == "" { + if transport == "websocket" && c.ModelOptions.Gateway() == "" && c.ModelOptions.TransportWrapper() == nil { stream, err := c.createWebSocketStream(ctx, params) if err != nil { slog.WarnContext(ctx, "WebSocket stream failed, falling back to SSE", "error", err) @@ -542,10 +545,13 @@ func (c *Client) CreateResponseStream( slog.DebugContext(ctx, "OpenAI responses WebSocket stream created successfully", "model", c.ModelConfig.Model) return newResponseStreamAdapter(stream, trackUsage), nil } - } else if transport == "websocket" { + } else if transport == "websocket" && c.ModelOptions.Gateway() != "" { slog.DebugContext(ctx, "WebSocket transport requested but Gateway is configured, using SSE", "model", c.ModelConfig.Model, "gateway", c.ModelOptions.Gateway()) + } else if transport == "websocket" { + slog.DebugContext(ctx, "WebSocket transport requested but HTTP transport wrapper is set, using SSE", + "model", c.ModelConfig.Model) } client, err := c.clientFn(ctx) diff --git a/pkg/model/provider/openai/wrap_transport_test.go b/pkg/model/provider/openai/wrap_transport_test.go index e8b8bfffb..95d4afdd8 100644 --- a/pkg/model/provider/openai/wrap_transport_test.go +++ b/pkg/model/provider/openai/wrap_transport_test.go @@ -110,3 +110,39 @@ func TestNewClient_TransportWrapperInvokedGatewayPath(t *testing.T) { assert.Positive(t, counter.calls.Load(), "transport wrapper RoundTrip should have been called at least once in gateway path") } + +// TestNewClient_WebSocketFallsBackToSSEWhenTransportWrapperSet verifies that +// configuring transport=websocket together with a transport wrapper causes the +// client to fall back to SSE (no wsPool created). gorilla/websocket bypasses +// http.RoundTripper, so websocket would silently drop the wrapper; the fallback +// ensures the wrapper covers every outbound request. +func TestNewClient_WebSocketFallsBackToSSEWhenTransportWrapperSet(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + writeSSEResponse(w) + })) + defer server.Close() + + var counter oaiCountingTransport + + cfg := &latest.ModelConfig{ + Provider: "openai", + Model: "gpt-4o-realtime-preview", + BaseURL: server.URL, + TokenKey: "OPENAI_API_KEY", + ProviderOpts: map[string]any{"transport": "websocket"}, + } + env := environment.NewMapEnvProvider(map[string]string{ + "OPENAI_API_KEY": "test-key", + }) + + client, err := NewClient(t.Context(), cfg, env, + options.WithHTTPTransportWrapper(func(base http.RoundTripper) http.RoundTripper { + counter.base = base + return &counter + }), + ) + require.NoError(t, err) + + // wsPool must not be created when a transport wrapper is registered. + assert.Nil(t, client.wsPool, "wsPool should be nil when a transport wrapper is set") +} From 05e6902354e653db16195c2492fc77be7ce6b23d Mon Sep 17 00:00:00 2001 From: Simon Ferquel's Clanker Date: Fri, 12 Jun 2026 14:13:48 +0000 Subject: [PATCH 6/7] =?UTF-8?q?fix:=20lint=20and=20apply=20Vertex=20AI=20?= =?UTF-8?q?=E2=86=92=20GeminiAPI=20fallback=20when=20transport=20wrapper?= =?UTF-8?q?=20is=20set?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Lint fix: - Rewrite the if-else chain in openai/client.go (CreateResponseStream) as a switch to satisfy the gocritic ifElseChain linter rule Vertex AI fallback (mirrors WebSocket → SSE): - In gemini/client.go, compute useVertexAI = wantVertexAI && TransportWrapper() == nil - When a transport wrapper is registered, override the Vertex AI backend selection and use BackendGeminiAPI + httpclient.NewHTTPClient so the wrapper can intercept all requests - Remove the slog.Warn that previously fired when wrapper + Vertex AI coexisted - Update wrap_transport_test.go: replace TestNewClient_TransportWrapperVertexAIWarns with TestNewClient_TransportWrapperVertexAIFallsBackToGeminiAPI which asserts the wrapper IS called on a real stream when project/location are set --- pkg/model/provider/gemini/client.go | 25 +++++++--- .../provider/gemini/wrap_transport_test.go | 49 +++++++++++++------ pkg/model/provider/openai/client.go | 7 +-- 3 files changed, 56 insertions(+), 25 deletions(-) diff --git a/pkg/model/provider/gemini/client.go b/pkg/model/provider/gemini/client.go index 7400753d9..7a4bd8488 100644 --- a/pkg/model/provider/gemini/client.go +++ b/pkg/model/provider/gemini/client.go @@ -61,8 +61,19 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro project string location string ) - // project/location take priority over API key, like in the genai client. - if cfg.ProviderOpts["project"] != nil || cfg.ProviderOpts["location"] != nil { + // Determine whether Vertex AI would normally be used, then check whether + // an HTTP transport wrapper forces a fallback to BackendGeminiAPI. + // The Vertex AI backend relies on ADC-managed HTTP clients that bypass + // http.RoundTripper, so the wrapper cannot be applied there. + _, useVertexAIEnv := env.Get(ctx, "GOOGLE_GENAI_USE_VERTEXAI") + wantVertexAI := cfg.ProviderOpts["project"] != nil || cfg.ProviderOpts["location"] != nil || useVertexAIEnv + useVertexAI := wantVertexAI && globalOptions.TransportWrapper() == nil + + if wantVertexAI && !useVertexAI { + slog.DebugContext(ctx, "Vertex AI requested but HTTP transport wrapper is set, falling back to GeminiAPI backend") + } + + if useVertexAI && (cfg.ProviderOpts["project"] != nil || cfg.ProviderOpts["location"] != nil) { var err error project, err = environment.Expand(ctx, providerOption(cfg, "project"), env) @@ -82,12 +93,12 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro } backend = genai.BackendVertexAI - httpClient = nil // Use default client - } else if _, exist := env.Get(ctx, "GOOGLE_GENAI_USE_VERTEXAI"); exist { + httpClient = nil // Use ADC-managed client + } else if useVertexAI { project, _ = env.Get(ctx, "GOOGLE_CLOUD_PROJECT") location, _ = env.Get(ctx, "GOOGLE_CLOUD_LOCATION") backend = genai.BackendVertexAI - httpClient = nil // Use default client + httpClient = nil // Use ADC-managed client } else { if value, exist := env.Get(ctx, "GEMINI_API_KEY"); exist { apiKey = value @@ -104,9 +115,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro } if w := globalOptions.TransportWrapper(); w != nil { - if httpClient == nil { - slog.WarnContext(ctx, "HTTP transport wrapper is set but not applied: Gemini Vertex AI backend uses an SDK-managed HTTP client") - } else if wrapped := w(httpClient.Transport); wrapped != nil { + if wrapped := w(httpClient.Transport); wrapped != nil { httpClient.Transport = wrapped } else { slog.WarnContext(ctx, "HTTP transport wrapper returned nil; using original transport") diff --git a/pkg/model/provider/gemini/wrap_transport_test.go b/pkg/model/provider/gemini/wrap_transport_test.go index fc63fc483..ee524e4f5 100644 --- a/pkg/model/provider/gemini/wrap_transport_test.go +++ b/pkg/model/provider/gemini/wrap_transport_test.go @@ -123,31 +123,52 @@ func TestNewClient_TransportWrapperInvokedGatewayPath(t *testing.T) { assert.Positive(t, counter.calls.Load(), "transport wrapper RoundTrip should have been called at least once in gateway path") } -func TestNewClient_TransportWrapperVertexAIWarns(t *testing.T) { - // When a Vertex AI backend is used (project+location configured), the genai - // SDK manages its own HTTP client. The transport wrapper cannot be applied. - // Verify that NewClient succeeds (no error) and the wrapper function itself - // is not called — the caller receives a slog warning instead. - var wrapperInvoked bool +// TestNewClient_TransportWrapperVertexAIFallsBackToGeminiAPI verifies that when +// project/location are configured (Vertex AI) but a transport wrapper is also +// set, the client automatically falls back to BackendGeminiAPI so the wrapper +// can be applied. This mirrors the WebSocket→SSE fallback in the OpenAI provider. +func TestNewClient_TransportWrapperVertexAIFallsBackToGeminiAPI(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + writeGeminiSSEResponse(w) + })) + defer server.Close() + + var counter geminiCountingTransport cfg := &latest.ModelConfig{ Provider: "google", Model: "gemini-2.0-flash", + BaseURL: server.URL, ProviderOpts: map[string]any{ "project": "test-project", "location": "us-central1", }, } - env := environment.NewMapEnvProvider(map[string]string{}) + // GOOGLE_API_KEY is required for the GeminiAPI fallback path. + env := environment.NewMapEnvProvider(map[string]string{ + "GOOGLE_API_KEY": "test-key", + }) - _, err := NewClient(t.Context(), cfg, env, + client, err := NewClient(t.Context(), cfg, env, options.WithHTTPTransportWrapper(func(base http.RoundTripper) http.RoundTripper { - wrapperInvoked = true - return &geminiCountingTransport{base: base} + counter.base = base + return &counter }), ) - // NewClient may fail because Vertex AI requires real ADC credentials; - // we only care that the wrapper function itself was NOT invoked. - _ = err - assert.False(t, wrapperInvoked, "wrapper function should not be invoked for Vertex AI backend") + require.NoError(t, err) + + stream, err := client.CreateChatCompletionStream(t.Context(), []chat.Message{ + {Role: chat.MessageRoleUser, Content: "hello"}, + }, nil) + require.NoError(t, err) + defer stream.Close() + + // Drain the stream so RoundTrip has been fully exercised. + for { + if _, err := stream.Recv(); err != nil { + break + } + } + + assert.Positive(t, counter.calls.Load(), "transport wrapper RoundTrip should have been called at least once (GeminiAPI fallback)") } diff --git a/pkg/model/provider/openai/client.go b/pkg/model/provider/openai/client.go index 21ea8e7e3..430d5d8aa 100644 --- a/pkg/model/provider/openai/client.go +++ b/pkg/model/provider/openai/client.go @@ -536,7 +536,8 @@ func (c *Client) CreateResponseStream( transport := getTransport(&c.ModelConfig) trackUsage := c.ModelConfig.TrackUsage == nil || *c.ModelConfig.TrackUsage - if transport == "websocket" && c.ModelOptions.Gateway() == "" && c.ModelOptions.TransportWrapper() == nil { + switch { + case transport == "websocket" && c.ModelOptions.Gateway() == "" && c.ModelOptions.TransportWrapper() == nil: stream, err := c.createWebSocketStream(ctx, params) if err != nil { slog.WarnContext(ctx, "WebSocket stream failed, falling back to SSE", "error", err) @@ -545,11 +546,11 @@ func (c *Client) CreateResponseStream( slog.DebugContext(ctx, "OpenAI responses WebSocket stream created successfully", "model", c.ModelConfig.Model) return newResponseStreamAdapter(stream, trackUsage), nil } - } else if transport == "websocket" && c.ModelOptions.Gateway() != "" { + case transport == "websocket" && c.ModelOptions.Gateway() != "": slog.DebugContext(ctx, "WebSocket transport requested but Gateway is configured, using SSE", "model", c.ModelConfig.Model, "gateway", c.ModelOptions.Gateway()) - } else if transport == "websocket" { + case transport == "websocket": slog.DebugContext(ctx, "WebSocket transport requested but HTTP transport wrapper is set, using SSE", "model", c.ModelConfig.Model) } From b00cf0e19d34bc5948f1f140592f1e83b2a6777e Mon Sep 17 00:00:00 2001 From: Simon Ferquel's Clanker Date: Fri, 12 Jun 2026 14:21:32 +0000 Subject: [PATCH 7/7] fix: rewrite gemini client if-else chain as switch (gocritic ifElseChain) --- pkg/model/provider/gemini/client.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pkg/model/provider/gemini/client.go b/pkg/model/provider/gemini/client.go index 7a4bd8488..c7e2e4394 100644 --- a/pkg/model/provider/gemini/client.go +++ b/pkg/model/provider/gemini/client.go @@ -73,7 +73,8 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro slog.DebugContext(ctx, "Vertex AI requested but HTTP transport wrapper is set, falling back to GeminiAPI backend") } - if useVertexAI && (cfg.ProviderOpts["project"] != nil || cfg.ProviderOpts["location"] != nil) { + switch { + case useVertexAI && (cfg.ProviderOpts["project"] != nil || cfg.ProviderOpts["location"] != nil): var err error project, err = environment.Expand(ctx, providerOption(cfg, "project"), env) @@ -94,12 +95,12 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro backend = genai.BackendVertexAI httpClient = nil // Use ADC-managed client - } else if useVertexAI { + case useVertexAI: project, _ = env.Get(ctx, "GOOGLE_CLOUD_PROJECT") location, _ = env.Get(ctx, "GOOGLE_CLOUD_LOCATION") backend = genai.BackendVertexAI httpClient = nil // Use ADC-managed client - } else { + default: if value, exist := env.Get(ctx, "GEMINI_API_KEY"); exist { apiKey = value }