diff --git a/bridge.go b/bridge.go index 511a5edc..72689cb3 100644 --- a/bridge.go +++ b/bridge.go @@ -198,7 +198,8 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC InitiatorID: actor.ID, Metadata: actor.Metadata, Model: interceptor.Model(), - Provider: p.Name(), + Provider: p.Type(), + ProviderName: p.Name(), UserAgent: r.UserAgent(), Client: string(client), ClientSessionID: sessionID, diff --git a/config/config.go b/config/config.go index 3e8cdf47..48f29bb3 100644 --- a/config/config.go +++ b/config/config.go @@ -9,6 +9,8 @@ const ( ) type Anthropic struct { + // Name is the provider instance name. If empty, defaults to "anthropic". + Name string BaseURL string Key string APIDumpDir string @@ -32,6 +34,8 @@ type AWSBedrock struct { } type OpenAI struct { + // Name is the provider instance name. If empty, defaults to "openai". + Name string BaseURL string Key string APIDumpDir string @@ -40,6 +44,14 @@ type OpenAI struct { ExtraHeaders map[string]string } +type Copilot struct { + // Name is the provider instance name. If empty, defaults to "copilot". + Name string + BaseURL string + APIDumpDir string + CircuitBreaker *CircuitBreaker +} + // CircuitBreaker holds configuration for circuit breakers. type CircuitBreaker struct { // MaxRequests is the maximum number of requests allowed in half-open state. @@ -67,9 +79,3 @@ func DefaultCircuitBreaker() CircuitBreaker { MaxRequests: 3, } } - -type Copilot struct { - BaseURL string - APIDumpDir string - CircuitBreaker *CircuitBreaker -} diff --git a/internal/testutil/mockprovider.go b/internal/testutil/mockprovider.go index a21ac6a4..06b8a2f2 100644 --- a/internal/testutil/mockprovider.go +++ b/internal/testutil/mockprovider.go @@ -17,6 +17,7 @@ type MockProvider struct { InterceptorFunc func(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error) } +func (m *MockProvider) Type() string { return m.Name_ } func (m *MockProvider) Name() string { return m.Name_ } func (m *MockProvider) BaseURL() string { return m.URL } func (m *MockProvider) RoutePrefix() string { return fmt.Sprintf("/%s", m.Name_) } diff --git a/provider/anthropic.go b/provider/anthropic.go index 4825db2d..c3a1235e 100644 --- a/provider/anthropic.go +++ b/provider/anthropic.go @@ -48,6 +48,9 @@ var anthropicIsFailure = func(statusCode int) bool { } func NewAnthropic(cfg config.Anthropic, bedrockCfg *config.AWSBedrock) *Anthropic { + if cfg.Name == "" { + cfg.Name = config.ProviderAnthropic + } if cfg.BaseURL == "" { cfg.BaseURL = "https://api.anthropic.com/" } @@ -68,10 +71,14 @@ func NewAnthropic(cfg config.Anthropic, bedrockCfg *config.AWSBedrock) *Anthropi } } -func (p *Anthropic) Name() string { +func (p *Anthropic) Type() string { return config.ProviderAnthropic } +func (p *Anthropic) Name() string { + return p.cfg.Name +} + func (p *Anthropic) RoutePrefix() string { return fmt.Sprintf("/%s", p.Name()) } diff --git a/provider/anthropic_test.go b/provider/anthropic_test.go index 5f19cdd2..3269c33d 100644 --- a/provider/anthropic_test.go +++ b/provider/anthropic_test.go @@ -14,6 +14,40 @@ import ( "github.com/coder/aibridge/internal/testutil" ) +func TestAnthropic_TypeAndName(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg config.Anthropic + expectType string + expectName string + }{ + { + name: "defaults", + cfg: config.Anthropic{}, + expectType: config.ProviderAnthropic, + expectName: config.ProviderAnthropic, + }, + { + name: "custom_name", + cfg: config.Anthropic{Name: "anthropic-custom"}, + expectType: config.ProviderAnthropic, + expectName: "anthropic-custom", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + p := NewAnthropic(tc.cfg, nil) + assert.Equal(t, tc.expectType, p.Type()) + assert.Equal(t, tc.expectName, p.Name()) + }) + } +} + func TestAnthropic_CreateInterceptor(t *testing.T) { t.Parallel() diff --git a/provider/copilot.go b/provider/copilot.go index 99b0fca3..7f60a6b5 100644 --- a/provider/copilot.go +++ b/provider/copilot.go @@ -52,6 +52,9 @@ type Copilot struct { var _ Provider = &Copilot{} func NewCopilot(cfg config.Copilot) *Copilot { + if cfg.Name == "" { + cfg.Name = config.ProviderCopilot + } if cfg.BaseURL == "" { cfg.BaseURL = copilotBaseURL } @@ -67,10 +70,14 @@ func NewCopilot(cfg config.Copilot) *Copilot { } } -func (p *Copilot) Name() string { +func (p *Copilot) Type() string { return config.ProviderCopilot } +func (p *Copilot) Name() string { + return p.cfg.Name +} + func (p *Copilot) BaseURL() string { return p.cfg.BaseURL } diff --git a/provider/copilot_test.go b/provider/copilot_test.go index d34a4208..4fea128b 100644 --- a/provider/copilot_test.go +++ b/provider/copilot_test.go @@ -17,6 +17,40 @@ import ( var testTracer = otel.Tracer("copilot_test") +func TestCopilot_TypeAndName(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg config.Copilot + expectType string + expectName string + }{ + { + name: "defaults", + cfg: config.Copilot{}, + expectType: config.ProviderCopilot, + expectName: config.ProviderCopilot, + }, + { + name: "custom_name", + cfg: config.Copilot{Name: "copilot-business"}, + expectType: config.ProviderCopilot, + expectName: "copilot-business", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + p := NewCopilot(tc.cfg) + assert.Equal(t, tc.expectType, p.Type()) + assert.Equal(t, tc.expectName, p.Name()) + }) + } +} + func TestCopilot_InjectAuthHeader(t *testing.T) { t.Parallel() diff --git a/provider/openai.go b/provider/openai.go index cfedf009..b794e85e 100644 --- a/provider/openai.go +++ b/provider/openai.go @@ -37,6 +37,9 @@ type OpenAI struct { var _ Provider = &OpenAI{} func NewOpenAI(cfg config.OpenAI) *OpenAI { + if cfg.Name == "" { + cfg.Name = config.ProviderOpenAI + } if cfg.BaseURL == "" { cfg.BaseURL = "https://api.openai.com/v1/" } @@ -56,10 +59,14 @@ func NewOpenAI(cfg config.OpenAI) *OpenAI { } } -func (p *OpenAI) Name() string { +func (p *OpenAI) Type() string { return config.ProviderOpenAI } +func (p *OpenAI) Name() string { + return p.cfg.Name +} + func (p *OpenAI) RoutePrefix() string { // Route prefix includes version to match default OpenAI base URL. // More detailed explanation: https://github.com/coder/aibridge/pull/174#discussion_r2782320152 diff --git a/provider/openai_test.go b/provider/openai_test.go index 18289417..dcdd2831 100644 --- a/provider/openai_test.go +++ b/provider/openai_test.go @@ -159,6 +159,40 @@ func generateResponsesPayload(payloadSize int, inputCount int, stream bool) []by return bodyBytes } +func TestOpenAI_TypeAndName(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg config.OpenAI + expectType string + expectName string + }{ + { + name: "defaults", + cfg: config.OpenAI{}, + expectType: config.ProviderOpenAI, + expectName: config.ProviderOpenAI, + }, + { + name: "custom_name", + cfg: config.OpenAI{Name: "openai-custom"}, + expectType: config.ProviderOpenAI, + expectName: "openai-custom", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + p := NewOpenAI(tc.cfg) + assert.Equal(t, tc.expectType, p.Type()) + assert.Equal(t, tc.expectName, p.Name()) + }) + } +} + func TestOpenAI_CreateInterceptor(t *testing.T) { t.Parallel() diff --git a/provider/provider.go b/provider/provider.go index f2a70f18..0e9fca3e 100644 --- a/provider/provider.go +++ b/provider/provider.go @@ -44,7 +44,11 @@ var UnknownRoute = errors.New("unknown route") // OpenAI includes the version '/v1' in the base url while Anthropic does not. // More details/examples: https://github.com/coder/aibridge/pull/174#discussion_r2782320152 type Provider interface { - // Name returns the provider's name. + // Type returns the provider type: "copilot", "openai", or "anthropic". + // Multiple provider instances can share the same type. + Type() string + // Name returns the provider instance name. + // Defaults to Type() when not explicitly configured. Name() string // BaseURL defines the base URL endpoint for this provider's API. BaseURL() string diff --git a/recorder/types.go b/recorder/types.go index 20e735f4..bd726fed 100644 --- a/recorder/types.go +++ b/recorder/types.go @@ -33,6 +33,7 @@ type InterceptionRecord struct { Metadata Metadata Model string Provider string + ProviderName string StartedAt time.Time ClientSessionID *string Client string