Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"net/http"
"net/url"
"regexp"
"strings"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -57,6 +58,25 @@ type RequestBridge struct {

var _ http.Handler = &RequestBridge{}

// validProviderName matches names containing only lowercase alphanumeric characters and hyphens.
var validProviderName = regexp.MustCompile(`^[a-z0-9]+(-[a-z0-9]+)*$`)

// validateProviders checks that provider names are valid and unique.
func validateProviders(providers []provider.Provider) error {
names := make(map[string]bool, len(providers))
for _, prov := range providers {
name := prov.Name()
if !validProviderName.MatchString(name) {
return fmt.Errorf("invalid provider name %q: must contain only lowercase alphanumeric characters and hyphens", name)
}
if names[name] {
return fmt.Errorf("duplicate provider name: %q", name)
}
names[name] = true
}
return nil
}

// NewRequestBridge creates a new *[RequestBridge] and registers the HTTP routes defined by the given providers.
// Any routes which are requested but not registered will be reverse-proxied to the upstream service.
//
Expand All @@ -67,6 +87,10 @@ var _ http.Handler = &RequestBridge{}
// Circuit breaker configuration is obtained from each provider's CircuitBreakerConfig() method.
// Providers returning nil will not have circuit breaker protection.
func NewRequestBridge(ctx context.Context, providers []provider.Provider, rec recorder.Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, m *metrics.Metrics, tracer trace.Tracer) (*RequestBridge, error) {
if err := validateProviders(providers); err != nil {
return nil, err
}

mux := http.NewServeMux()

for _, prov := range providers {
Expand Down
119 changes: 119 additions & 0 deletions bridge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,125 @@ import (
"github.com/stretchr/testify/require"
)

func TestValidateProvider_Names(t *testing.T) {
t.Parallel()

tests := []struct {
name string
providers []provider.Provider
expectErr string
}{
{
name: "all_supported_providers",
providers: []provider.Provider{
NewOpenAIProvider(config.OpenAI{Name: "openai", BaseURL: "https://api.openai.com/v1/"}),
NewAnthropicProvider(config.Anthropic{Name: "anthropic", BaseURL: "https://api.anthropic.com/"}, nil),
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
NewCopilotProvider(config.Copilot{Name: "copilot-enterprise", BaseURL: "https://api.enterprise.githubcopilot.com"}),
},
},
{
name: "default_names_and_base_urls",
providers: []provider.Provider{
NewOpenAIProvider(config.OpenAI{}),
NewAnthropicProvider(config.Anthropic{}, nil),
NewCopilotProvider(config.Copilot{}),
},
},
{
name: "multiple_copilot_instances",
providers: []provider.Provider{
NewCopilotProvider(config.Copilot{}),
NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
NewCopilotProvider(config.Copilot{Name: "copilot-enterprise", BaseURL: "https://api.enterprise.githubcopilot.com"}),
},
},
{
name: "name_with_slashes",
providers: []provider.Provider{
NewCopilotProvider(config.Copilot{Name: "copilot/business", BaseURL: "https://api.business.githubcopilot.com"}),
},
expectErr: "invalid provider name",
},
{
name: "name_with_spaces",
providers: []provider.Provider{
NewCopilotProvider(config.Copilot{Name: "copilot business", BaseURL: "https://api.business.githubcopilot.com"}),
},
expectErr: "invalid provider name",
},
{
name: "name_with_uppercase",
providers: []provider.Provider{
NewCopilotProvider(config.Copilot{Name: "Copilot", BaseURL: "https://api.business.githubcopilot.com"}),
},
expectErr: "invalid provider name",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

err := validateProviders(tc.providers)
if tc.expectErr != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tc.expectErr)
} else {
require.NoError(t, err)
}
})
}
}

func TestValidateProvider_DuplicateNames(t *testing.T) {
t.Parallel()

tests := []struct {
name string
providers []provider.Provider
expectErr string
}{
{
name: "unique_names",
providers: []provider.Provider{
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
},
},
{
name: "duplicate_base_url_different_names",
providers: []provider.Provider{
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.individual.githubcopilot.com"}),
},
},
{
name: "duplicate_name",
providers: []provider.Provider{
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.business.githubcopilot.com"}),
},
expectErr: "duplicate provider name",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

err := validateProviders(tc.providers)
if tc.expectErr != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tc.expectErr)
} else {
require.NoError(t, err)
}
})
}
}

func TestPassthroughRoutesForProviders(t *testing.T) {
t.Parallel()

Expand Down
Loading