diff --git a/bridge.go b/bridge.go index 72689cb..4d79fba 100644 --- a/bridge.go +++ b/bridge.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "net/url" + "regexp" "strings" "sync" "sync/atomic" @@ -57,6 +58,25 @@ type RequestBridge struct { var _ http.Handler = &RequestBridge{} +// validProviderName matches names containing only lowercase alphanumeric characters and hyphens. +var validProviderName = regexp.MustCompile(`^[a-z0-9]+(-[a-z0-9]+)*$`) + +// validateProviders checks that provider names are valid and unique. +func validateProviders(providers []provider.Provider) error { + names := make(map[string]bool, len(providers)) + for _, prov := range providers { + name := prov.Name() + if !validProviderName.MatchString(name) { + return fmt.Errorf("invalid provider name %q: must contain only lowercase alphanumeric characters and hyphens", name) + } + if names[name] { + return fmt.Errorf("duplicate provider name: %q", name) + } + names[name] = true + } + return nil +} + // NewRequestBridge creates a new *[RequestBridge] and registers the HTTP routes defined by the given providers. // Any routes which are requested but not registered will be reverse-proxied to the upstream service. // @@ -67,6 +87,10 @@ var _ http.Handler = &RequestBridge{} // Circuit breaker configuration is obtained from each provider's CircuitBreakerConfig() method. // Providers returning nil will not have circuit breaker protection. func NewRequestBridge(ctx context.Context, providers []provider.Provider, rec recorder.Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, m *metrics.Metrics, tracer trace.Tracer) (*RequestBridge, error) { + if err := validateProviders(providers); err != nil { + return nil, err + } + mux := http.NewServeMux() for _, prov := range providers { diff --git a/bridge_test.go b/bridge_test.go index 1709be1..161f3f3 100644 --- a/bridge_test.go +++ b/bridge_test.go @@ -13,6 +13,125 @@ import ( "github.com/stretchr/testify/require" ) +func TestValidateProvider_Names(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + providers []provider.Provider + expectErr string + }{ + { + name: "all_supported_providers", + providers: []provider.Provider{ + NewOpenAIProvider(config.OpenAI{Name: "openai", BaseURL: "https://api.openai.com/v1/"}), + NewAnthropicProvider(config.Anthropic{Name: "anthropic", BaseURL: "https://api.anthropic.com/"}, nil), + NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}), + NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}), + NewCopilotProvider(config.Copilot{Name: "copilot-enterprise", BaseURL: "https://api.enterprise.githubcopilot.com"}), + }, + }, + { + name: "default_names_and_base_urls", + providers: []provider.Provider{ + NewOpenAIProvider(config.OpenAI{}), + NewAnthropicProvider(config.Anthropic{}, nil), + NewCopilotProvider(config.Copilot{}), + }, + }, + { + name: "multiple_copilot_instances", + providers: []provider.Provider{ + NewCopilotProvider(config.Copilot{}), + NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}), + NewCopilotProvider(config.Copilot{Name: "copilot-enterprise", BaseURL: "https://api.enterprise.githubcopilot.com"}), + }, + }, + { + name: "name_with_slashes", + providers: []provider.Provider{ + NewCopilotProvider(config.Copilot{Name: "copilot/business", BaseURL: "https://api.business.githubcopilot.com"}), + }, + expectErr: "invalid provider name", + }, + { + name: "name_with_spaces", + providers: []provider.Provider{ + NewCopilotProvider(config.Copilot{Name: "copilot business", BaseURL: "https://api.business.githubcopilot.com"}), + }, + expectErr: "invalid provider name", + }, + { + name: "name_with_uppercase", + providers: []provider.Provider{ + NewCopilotProvider(config.Copilot{Name: "Copilot", BaseURL: "https://api.business.githubcopilot.com"}), + }, + expectErr: "invalid provider name", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + err := validateProviders(tc.providers) + if tc.expectErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectErr) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestValidateProvider_DuplicateNames(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + providers []provider.Provider + expectErr string + }{ + { + name: "unique_names", + providers: []provider.Provider{ + NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}), + NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}), + }, + }, + { + name: "duplicate_base_url_different_names", + providers: []provider.Provider{ + NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}), + NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.individual.githubcopilot.com"}), + }, + }, + { + name: "duplicate_name", + providers: []provider.Provider{ + NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}), + NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.business.githubcopilot.com"}), + }, + expectErr: "duplicate provider name", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + err := validateProviders(tc.providers) + if tc.expectErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectErr) + } else { + require.NoError(t, err) + } + }) + } +} + func TestPassthroughRoutesForProviders(t *testing.T) { t.Parallel()