From f43d1e0ffa0dbac21dbf347995b639cd047ca736 Mon Sep 17 00:00:00 2001 From: Susana Cardoso Ferreira Date: Sat, 28 Mar 2026 16:58:04 +0000 Subject: [PATCH] refactor: separate OpenAI provider and interceptor configs --- config/config.go | 8 ++++++++ intercept/chatcompletions/base.go | 8 +++++--- intercept/chatcompletions/blocking.go | 6 +++++- intercept/chatcompletions/streaming.go | 6 +++++- intercept/chatcompletions/streaming_test.go | 7 +++---- intercept/responses/base.go | 8 +++++--- intercept/responses/blocking.go | 6 +++++- intercept/responses/streaming.go | 6 +++++- provider/copilot.go | 22 +++++++++------------ provider/openai.go | 17 ++++++++++------ 10 files changed, 61 insertions(+), 33 deletions(-) diff --git a/config/config.go b/config/config.go index 3e8cdf4..76c1c0a 100644 --- a/config/config.go +++ b/config/config.go @@ -31,12 +31,20 @@ type AWSBedrock struct { BaseURL string } +// OpenAI contains provider-level configuration for the OpenAI provider. type OpenAI struct { BaseURL string Key string APIDumpDir string CircuitBreaker *CircuitBreaker SendActorHeaders bool +} + +// OpenAIInterceptor contains configuration for interceptors that speak the +// OpenAI wire format. Used by any provider that uses OpenAI-compatible APIs. +type OpenAIInterceptor struct { + Key string + SendActorHeaders bool ExtraHeaders map[string]string } diff --git a/intercept/chatcompletions/base.go b/intercept/chatcompletions/base.go index 03d9131..7345698 100644 --- a/intercept/chatcompletions/base.go +++ b/intercept/chatcompletions/base.go @@ -28,8 +28,10 @@ import ( type interceptionBase struct { id uuid.UUID providerName string + baseURL string + apiDumpDir string req *ChatCompletionNewParamsWrapper - cfg config.OpenAI + cfg config.OpenAIInterceptor // clientHeaders are the original HTTP headers from the client request. clientHeaders http.Header @@ -43,7 +45,7 @@ type interceptionBase struct { } func (i *interceptionBase) newCompletionsService() openai.ChatCompletionService { - opts := []option.RequestOption{option.WithAPIKey(i.cfg.Key), option.WithBaseURL(i.cfg.BaseURL)} + opts := []option.RequestOption{option.WithAPIKey(i.cfg.Key), option.WithBaseURL(i.baseURL)} // Add extra headers if configured. // Some providers require additional headers that are not added by the SDK. @@ -63,7 +65,7 @@ func (i *interceptionBase) newCompletionsService() openai.ChatCompletionService } // Add API dump middleware if configured - if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, i.ProviderName(), i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { + if mw := apidump.NewBridgeMiddleware(i.apiDumpDir, i.ProviderName(), i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { opts = append(opts, option.WithMiddleware(mw)) } diff --git a/intercept/chatcompletions/blocking.go b/intercept/chatcompletions/blocking.go index 9e398d0..5644726 100644 --- a/intercept/chatcompletions/blocking.go +++ b/intercept/chatcompletions/blocking.go @@ -32,7 +32,9 @@ func NewBlockingInterceptor( id uuid.UUID, req *ChatCompletionNewParamsWrapper, providerName string, - cfg config.OpenAI, + baseURL string, + apiDumpDir string, + cfg config.OpenAIInterceptor, clientHeaders http.Header, authHeaderName string, tracer trace.Tracer, @@ -40,6 +42,8 @@ func NewBlockingInterceptor( return &BlockingInterception{interceptionBase: interceptionBase{ id: id, providerName: providerName, + baseURL: baseURL, + apiDumpDir: apiDumpDir, req: req, cfg: cfg, clientHeaders: clientHeaders, diff --git a/intercept/chatcompletions/streaming.go b/intercept/chatcompletions/streaming.go index 1d705e3..144888c 100644 --- a/intercept/chatcompletions/streaming.go +++ b/intercept/chatcompletions/streaming.go @@ -37,7 +37,9 @@ func NewStreamingInterceptor( id uuid.UUID, req *ChatCompletionNewParamsWrapper, providerName string, - cfg config.OpenAI, + baseURL string, + apiDumpDir string, + cfg config.OpenAIInterceptor, clientHeaders http.Header, authHeaderName string, tracer trace.Tracer, @@ -45,6 +47,8 @@ func NewStreamingInterceptor( return &StreamingInterception{interceptionBase: interceptionBase{ id: id, providerName: providerName, + baseURL: baseURL, + apiDumpDir: apiDumpDir, req: req, cfg: cfg, clientHeaders: clientHeaders, diff --git a/intercept/chatcompletions/streaming_test.go b/intercept/chatcompletions/streaming_test.go index 54c4733..7b6214a 100644 --- a/intercept/chatcompletions/streaming_test.go +++ b/intercept/chatcompletions/streaming_test.go @@ -66,9 +66,8 @@ func TestStreamingInterception_RelaysUpstreamErrorToClient(t *testing.T) { t.Cleanup(mockServer.Close) // Create interceptor with mock server URL - cfg := config.OpenAI{ - BaseURL: mockServer.URL, - Key: "test-key", + cfg := config.OpenAIInterceptor{ + Key: "test-key", } req := &ChatCompletionNewParamsWrapper{ @@ -86,7 +85,7 @@ func TestStreamingInterception_RelaysUpstreamErrorToClient(t *testing.T) { httpReq := httptest.NewRequest(http.MethodPost, "/chat/completions", nil) tracer := otel.Tracer("test") - interceptor := NewStreamingInterceptor(uuid.New(), req, config.ProviderOpenAI, cfg, httpReq.Header, "Authorization", tracer) + interceptor := NewStreamingInterceptor(uuid.New(), req, config.ProviderOpenAI, mockServer.URL, "", cfg, httpReq.Header, "Authorization", tracer) logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) interceptor.Setup(logger, &testutil.MockRecorder{}, nil) diff --git a/intercept/responses/base.go b/intercept/responses/base.go index 1f50fce..e63d314 100644 --- a/intercept/responses/base.go +++ b/intercept/responses/base.go @@ -38,12 +38,14 @@ const ( type responsesInterceptionBase struct { id uuid.UUID providerName string + baseURL string + apiDumpDir string // clientHeaders are the original HTTP headers from the client request. clientHeaders http.Header authHeaderName string reqPayload ResponsesRequestPayload - cfg config.OpenAI + cfg config.OpenAIInterceptor recorder recorder.Recorder mcpProxy mcp.ServerProxier @@ -52,7 +54,7 @@ type responsesInterceptionBase struct { } func (i *responsesInterceptionBase) newResponsesService() responses.ResponseService { - opts := []option.RequestOption{option.WithBaseURL(i.cfg.BaseURL), option.WithAPIKey(i.cfg.Key)} + opts := []option.RequestOption{option.WithBaseURL(i.baseURL), option.WithAPIKey(i.cfg.Key)} // Add extra headers if configured. // Some providers require additional headers that are not added by the SDK. @@ -72,7 +74,7 @@ func (i *responsesInterceptionBase) newResponsesService() responses.ResponseServ } // Add API dump middleware if configured - if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, i.ProviderName(), i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { + if mw := apidump.NewBridgeMiddleware(i.apiDumpDir, i.ProviderName(), i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { opts = append(opts, option.WithMiddleware(mw)) } diff --git a/intercept/responses/blocking.go b/intercept/responses/blocking.go index d64adf9..f2e0a55 100644 --- a/intercept/responses/blocking.go +++ b/intercept/responses/blocking.go @@ -29,7 +29,9 @@ func NewBlockingInterceptor( id uuid.UUID, reqPayload ResponsesRequestPayload, providerName string, - cfg config.OpenAI, + baseURL string, + apiDumpDir string, + cfg config.OpenAIInterceptor, clientHeaders http.Header, authHeaderName string, tracer trace.Tracer, @@ -38,6 +40,8 @@ func NewBlockingInterceptor( responsesInterceptionBase: responsesInterceptionBase{ id: id, providerName: providerName, + baseURL: baseURL, + apiDumpDir: apiDumpDir, reqPayload: reqPayload, cfg: cfg, clientHeaders: clientHeaders, diff --git a/intercept/responses/streaming.go b/intercept/responses/streaming.go index 359f82e..606c6ee 100644 --- a/intercept/responses/streaming.go +++ b/intercept/responses/streaming.go @@ -36,7 +36,9 @@ func NewStreamingInterceptor( id uuid.UUID, reqPayload ResponsesRequestPayload, providerName string, - cfg config.OpenAI, + baseURL string, + apiDumpDir string, + cfg config.OpenAIInterceptor, clientHeaders http.Header, authHeaderName string, tracer trace.Tracer, @@ -45,6 +47,8 @@ func NewStreamingInterceptor( responsesInterceptionBase: responsesInterceptionBase{ id: id, providerName: providerName, + baseURL: baseURL, + apiDumpDir: apiDumpDir, reqPayload: reqPayload, cfg: cfg, clientHeaders: clientHeaders, diff --git a/provider/copilot.go b/provider/copilot.go index abf19c8..00be843 100644 --- a/provider/copilot.go +++ b/provider/copilot.go @@ -127,15 +127,11 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac id := uuid.New() - // Build config for the interceptor using the per-request key. - // Copilot's API is OpenAI-compatible, so it uses the OpenAI interceptors - // that require a config.OpenAI. - cfg := config.OpenAI{ - BaseURL: p.cfg.BaseURL, - Key: key, - APIDumpDir: p.cfg.APIDumpDir, - CircuitBreaker: p.cfg.CircuitBreaker, - ExtraHeaders: extractCopilotHeaders(r), + // Build interceptor config using the per-request key. + // Copilot's API is OpenAI-compatible, so it uses the OpenAI interceptors. + interceptorCfg := config.OpenAIInterceptor{ + Key: key, + ExtraHeaders: extractCopilotHeaders(r), } var interceptor intercept.Interceptor @@ -149,9 +145,9 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac } if req.Stream { - interceptor = chatcompletions.NewStreamingInterceptor(id, &req, p.Name(), cfg, r.Header, p.AuthHeader(), tracer) + interceptor = chatcompletions.NewStreamingInterceptor(id, &req, p.Name(), p.cfg.BaseURL, p.cfg.APIDumpDir, interceptorCfg, r.Header, p.AuthHeader(), tracer) } else { - interceptor = chatcompletions.NewBlockingInterceptor(id, &req, p.Name(), cfg, r.Header, p.AuthHeader(), tracer) + interceptor = chatcompletions.NewBlockingInterceptor(id, &req, p.Name(), p.cfg.BaseURL, p.cfg.APIDumpDir, interceptorCfg, r.Header, p.AuthHeader(), tracer) } case routeCopilotResponses: @@ -165,9 +161,9 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac } if reqPayload.Stream() { - interceptor = responses.NewStreamingInterceptor(id, reqPayload, p.Name(), cfg, r.Header, p.AuthHeader(), tracer) + interceptor = responses.NewStreamingInterceptor(id, reqPayload, p.Name(), p.cfg.BaseURL, p.cfg.APIDumpDir, interceptorCfg, r.Header, p.AuthHeader(), tracer) } else { - interceptor = responses.NewBlockingInterceptor(id, reqPayload, p.Name(), cfg, r.Header, p.AuthHeader(), tracer) + interceptor = responses.NewBlockingInterceptor(id, reqPayload, p.Name(), p.cfg.BaseURL, p.cfg.APIDumpDir, interceptorCfg, r.Header, p.AuthHeader(), tracer) } default: diff --git a/provider/openai.go b/provider/openai.go index a1624e1..ce84170 100644 --- a/provider/openai.go +++ b/provider/openai.go @@ -97,7 +97,6 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace var interceptor intercept.Interceptor - cfg := p.cfg // At this point the request contains only LLM provider headers. Any // Coder-specific authentication has already been stripped. // @@ -106,8 +105,14 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace // // In BYOK mode the user's credential is in Authorization. Replace // the centralized key with it so it is forwarded upstream. + key := p.cfg.Key if token := utils.ExtractBearerToken(r.Header.Get("Authorization")); token != "" { - cfg.Key = token + key = token + } + + interceptorCfg := config.OpenAIInterceptor{ + Key: key, + SendActorHeaders: p.cfg.SendActorHeaders, } path := strings.TrimPrefix(r.URL.Path, p.RoutePrefix()) @@ -119,9 +124,9 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace } if req.Stream { - interceptor = chatcompletions.NewStreamingInterceptor(id, &req, p.Name(), cfg, r.Header, p.AuthHeader(), tracer) + interceptor = chatcompletions.NewStreamingInterceptor(id, &req, p.Name(), p.cfg.BaseURL, p.cfg.APIDumpDir, interceptorCfg, r.Header, p.AuthHeader(), tracer) } else { - interceptor = chatcompletions.NewBlockingInterceptor(id, &req, p.Name(), cfg, r.Header, p.AuthHeader(), tracer) + interceptor = chatcompletions.NewBlockingInterceptor(id, &req, p.Name(), p.cfg.BaseURL, p.cfg.APIDumpDir, interceptorCfg, r.Header, p.AuthHeader(), tracer) } case routeResponses: @@ -134,9 +139,9 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace return nil, fmt.Errorf("unmarshal request body: %w", err) } if reqPayload.Stream() { - interceptor = responses.NewStreamingInterceptor(id, reqPayload, p.Name(), cfg, r.Header, p.AuthHeader(), tracer) + interceptor = responses.NewStreamingInterceptor(id, reqPayload, p.Name(), p.cfg.BaseURL, p.cfg.APIDumpDir, interceptorCfg, r.Header, p.AuthHeader(), tracer) } else { - interceptor = responses.NewBlockingInterceptor(id, reqPayload, p.Name(), cfg, r.Header, p.AuthHeader(), tracer) + interceptor = responses.NewBlockingInterceptor(id, reqPayload, p.Name(), p.cfg.BaseURL, p.cfg.APIDumpDir, interceptorCfg, r.Header, p.AuthHeader(), tracer) } default: