Skip to content
20 changes: 18 additions & 2 deletions pkg/model/provider/anthropic/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,16 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
slog.ErrorContext(ctx, "Anthropic client creation failed", "error", err)
return nil, err
}
httpClient := httpclient.NewHTTPClient(ctx)
if w := globalOptions.TransportWrapper(); w != nil {
if wrapped := w(httpClient.Transport); wrapped != nil {
httpClient.Transport = wrapped
} else {
slog.WarnContext(ctx, "HTTP transport wrapper returned nil; using original transport")
}
}
requestOptions := append([]option.RequestOption{
option.WithHTTPClient(httpclient.NewHTTPClient(ctx)),
option.WithHTTPClient(httpClient),
}, authOpts...)
if cfg.BaseURL != "" {
requestOptions = append(requestOptions, option.WithBaseURL(cfg.BaseURL))
Expand Down Expand Up @@ -127,9 +135,17 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
httpOptions = append(httpOptions, httpclient.WithHeader("X-Cagent-GeneratingTitle", "1"))
}

gatewayHTTPClient := httpclient.NewHTTPClient(ctx, httpOptions...)
if w := globalOptions.TransportWrapper(); w != nil {
if wrapped := w(gatewayHTTPClient.Transport); wrapped != nil {
gatewayHTTPClient.Transport = wrapped
} else {
slog.WarnContext(ctx, "HTTP transport wrapper returned nil; using original transport")
}
}
clientOptions := []option.RequestOption{
option.WithBaseURL(baseURL),
option.WithHTTPClient(httpclient.NewHTTPClient(ctx, httpOptions...)),
option.WithHTTPClient(gatewayHTTPClient),
}
if authToken != "" {
clientOptions = append(clientOptions, option.WithAuthToken(authToken), option.WithAPIKey(authToken))
Expand Down
157 changes: 157 additions & 0 deletions pkg/model/provider/anthropic/wrap_transport_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
package anthropic

import (
"encoding/json"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/docker/docker-agent/pkg/chat"
"github.com/docker/docker-agent/pkg/config/latest"
"github.com/docker/docker-agent/pkg/environment"
"github.com/docker/docker-agent/pkg/model/provider/options"
)

// countingTransport wraps a base RoundTripper and counts how many times
// RoundTrip is called.
type countingTransport struct {
base http.RoundTripper
calls atomic.Int64
}

func (c *countingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
c.calls.Add(1)
return c.base.RoundTrip(req)
}

// writeMinimalAnthropicSSE writes a bare-minimum valid Anthropic SSE stream
// so that the streaming client does not error before we can observe transport invocation.
func writeMinimalAnthropicSSE(w http.ResponseWriter) {
w.Header().Set("Content-Type", "text/event-stream")
flusher, _ := w.(http.Flusher)

writeEvent := func(eventType string, payload any) {
data, _ := json.Marshal(payload)
_, _ = w.Write([]byte("event: " + eventType + "\n"))
_, _ = w.Write([]byte("data: " + string(data) + "\n\n"))
if flusher != nil {
flusher.Flush()
}
}

writeEvent("message_start", map[string]any{
"type": "message_start",
"message": map[string]any{"id": "msg_test", "model": "claude-test", "role": "assistant", "type": "message", "content": []any{}, "stop_reason": nil, "usage": map[string]any{"input_tokens": 5, "output_tokens": 0}},
})
writeEvent("content_block_start", map[string]any{
"type": "content_block_start",
"index": 0,
"content_block": map[string]any{
"type": "text",
"text": "",
},
})
writeEvent("content_block_delta", map[string]any{
"type": "content_block_delta",
"index": 0,
"delta": map[string]any{"type": "text_delta", "text": "hi"},
})
writeEvent("content_block_stop", map[string]any{
"type": "content_block_stop",
"index": 0,
})
writeEvent("message_delta", map[string]any{
"type": "message_delta",
"delta": map[string]any{"stop_reason": "end_turn", "stop_sequence": nil},
"usage": map[string]any{"output_tokens": 1},
})
writeEvent("message_stop", map[string]any{"type": "message_stop"})
}

func TestNewClient_TransportWrapperInvokedDirectPath(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
writeMinimalAnthropicSSE(w)
}))
defer server.Close()

var counter countingTransport

cfg := &latest.ModelConfig{
Provider: "anthropic",
Model: "claude-3-5-haiku-latest",
BaseURL: server.URL,
}
env := environment.NewMapEnvProvider(map[string]string{
"ANTHROPIC_API_KEY": "test-key",
})

client, err := NewClient(t.Context(), cfg, env,
options.WithHTTPTransportWrapper(func(base http.RoundTripper) http.RoundTripper {
counter.base = base
return &counter
}),
)
require.NoError(t, err)

stream, err := client.CreateChatCompletionStream(t.Context(), []chat.Message{
{Role: chat.MessageRoleUser, Content: "hello"},
}, nil)
require.NoError(t, err)
defer stream.Close()

// Drain the stream so RoundTrip has been fully exercised.
for {
if _, err := stream.Recv(); err != nil {
break
}
}

assert.Positive(t, counter.calls.Load(), "transport wrapper RoundTrip should have been called at least once")
}

func TestNewClient_TransportWrapperInvokedGatewayPath(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
writeMinimalAnthropicSSE(w)
}))
defer server.Close()

var counter countingTransport

cfg := &latest.ModelConfig{
Provider: "anthropic",
Model: "claude-3-5-haiku-latest",
}
// server.URL is 127.0.0.1 which IsTrustedDockerURL considers trusted,
// so we must supply the Docker Desktop token.
env := environment.NewMapEnvProvider(map[string]string{
environment.DockerDesktopTokenEnv: "test-dd-token",
})

client, err := NewClient(t.Context(), cfg, env,
options.WithGateway(server.URL),
options.WithHTTPTransportWrapper(func(base http.RoundTripper) http.RoundTripper {
counter.base = base
return &counter
}),
)
require.NoError(t, err)

stream, err := client.CreateChatCompletionStream(t.Context(), []chat.Message{
{Role: chat.MessageRoleUser, Content: "hello"},
}, nil)
require.NoError(t, err)
defer stream.Close()

// Drain the stream so RoundTrip has been fully exercised.
for {
if _, err := stream.Recv(); err != nil {
break
}
}

assert.Positive(t, counter.calls.Load(), "transport wrapper RoundTrip should have been called at least once in gateway path")
}
40 changes: 31 additions & 9 deletions pkg/model/provider/bedrock/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/docker/docker-agent/pkg/chat"
"github.com/docker/docker-agent/pkg/config/latest"
"github.com/docker/docker-agent/pkg/environment"
"github.com/docker/docker-agent/pkg/httpclient"
"github.com/docker/docker-agent/pkg/model/provider/base"
"github.com/docker/docker-agent/pkg/model/provider/options"
"github.com/docker/docker-agent/pkg/model/provider/providerutil"
Expand Down Expand Up @@ -60,7 +61,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
opt(&globalOptions)
}

// Check for bearer token - use token_key if specified, otherwise try AWS_BEARER_TOKEN_BEDROCK.
// Check for bearer token
// Bearer token is optional: if not provided, falls back to standard AWS credential chain (SigV4).
//
// NOTE: Manual token handling is required because aws-sdk-go-v2's default credential chain
Expand All @@ -77,6 +78,28 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
bearerToken, _ = env.Get(ctx, "AWS_BEARER_TOKEN_BEDROCK")
}

// Build the docker-agent HTTP client (OTel instrumentation, SSE decompression,
// Desktop proxy support) so transport-level concerns apply to Bedrock too.
httpClient := httpclient.NewHTTPClient(ctx)

// If a bearer token is set, chain it on top of the base transport so auth
// headers are injected without replacing the rest of the transport stack.
if bearerToken != "" {
httpClient.Transport = &bearerTokenTransport{
token: bearerToken,
base: httpClient.Transport,
}
}

// Apply the transport wrapper, if registered, over the full chain.
if w := globalOptions.TransportWrapper(); w != nil {
if wrapped := w(httpClient.Transport); wrapped != nil {
httpClient.Transport = wrapped
} else {
slog.WarnContext(ctx, "HTTP transport wrapper returned nil; using original transport")
}
}

// Build AWS config using default credential chain
awsCfg, err := buildAWSConfig(ctx, cfg, env)
if err != nil {
Expand All @@ -94,19 +117,18 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
})
}

// If bearer token is set, use it instead of SigV4
// Inject our HTTP client (which carries OTel, SSE, bearer token if set, and
// any caller-registered transport wrapper) into the Bedrock runtime options.
if bearerToken != "" {
slog.DebugContext(ctx, "Bedrock using bearer token authentication")
clientOpts = append(clientOpts, func(o *bedrockruntime.Options) {
// Use anonymous credentials to skip SigV4 signing
o.Credentials = aws.AnonymousCredentials{}
// Add bearer token via custom HTTP client
o.HTTPClient = &http.Client{
Transport: &bearerTokenTransport{
token: bearerToken,
base: http.DefaultTransport,
},
}
o.HTTPClient = httpClient
})
} else {
clientOpts = append(clientOpts, func(o *bedrockruntime.Options) {
o.HTTPClient = httpClient
})
}

Expand Down
43 changes: 36 additions & 7 deletions pkg/model/provider/gemini/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,20 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
project string
location string
)
// project/location take priority over API key, like in the genai client.
if cfg.ProviderOpts["project"] != nil || cfg.ProviderOpts["location"] != nil {
// Determine whether Vertex AI would normally be used, then check whether
// an HTTP transport wrapper forces a fallback to BackendGeminiAPI.
// The Vertex AI backend relies on ADC-managed HTTP clients that bypass
// http.RoundTripper, so the wrapper cannot be applied there.
_, useVertexAIEnv := env.Get(ctx, "GOOGLE_GENAI_USE_VERTEXAI")
wantVertexAI := cfg.ProviderOpts["project"] != nil || cfg.ProviderOpts["location"] != nil || useVertexAIEnv
useVertexAI := wantVertexAI && globalOptions.TransportWrapper() == nil

if wantVertexAI && !useVertexAI {
slog.DebugContext(ctx, "Vertex AI requested but HTTP transport wrapper is set, falling back to GeminiAPI backend")
}

switch {
case useVertexAI && (cfg.ProviderOpts["project"] != nil || cfg.ProviderOpts["location"] != nil):
var err error

project, err = environment.Expand(ctx, providerOption(cfg, "project"), env)
Expand All @@ -82,13 +94,13 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
}

backend = genai.BackendVertexAI
httpClient = nil // Use default client
} else if _, exist := env.Get(ctx, "GOOGLE_GENAI_USE_VERTEXAI"); exist {
httpClient = nil // Use ADC-managed client
case useVertexAI:
project, _ = env.Get(ctx, "GOOGLE_CLOUD_PROJECT")
location, _ = env.Get(ctx, "GOOGLE_CLOUD_LOCATION")
backend = genai.BackendVertexAI
httpClient = nil // Use default client
} else {
httpClient = nil // Use ADC-managed client
default:
if value, exist := env.Get(ctx, "GEMINI_API_KEY"); exist {
apiKey = value
}
Expand All @@ -103,6 +115,14 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
httpClient = httpclient.NewHTTPClient(ctx)
}

if w := globalOptions.TransportWrapper(); w != nil {
if wrapped := w(httpClient.Transport); wrapped != nil {
httpClient.Transport = wrapped
} else {
slog.WarnContext(ctx, "HTTP transport wrapper returned nil; using original transport")
}
}

client, err := genai.NewClient(ctx, &genai.ClientConfig{
APIKey: apiKey,
Project: project,
Expand Down Expand Up @@ -168,10 +188,19 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
}
}

gatewayHTTPClient := httpclient.NewHTTPClient(ctx, httpOptions...)
if w := globalOptions.TransportWrapper(); w != nil {
if wrapped := w(gatewayHTTPClient.Transport); wrapped != nil {
gatewayHTTPClient.Transport = wrapped
} else {
slog.WarnContext(ctx, "HTTP transport wrapper returned nil; using original transport")
}
}

return genai.NewClient(ctx, &genai.ClientConfig{
APIKey: authToken,
Backend: genai.BackendGeminiAPI,
HTTPClient: httpclient.NewHTTPClient(ctx, httpOptions...),
HTTPClient: gatewayHTTPClient,
HTTPOptions: httpOpts,
})
}
Expand Down
Loading
Loading