Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 31 additions & 24 deletions controlplane/control.go
Original file line number Diff line number Diff line change
Expand Up @@ -1768,46 +1768,55 @@ func (cp *ControlPlane) drainAfterUpgrade() {
// within that org. Flight has no `database` param, so there is no catalog
// selection here — the per-user default catalog applies.
type cpFlightCredentialValidator struct {
	// cp provides SNI parsing (extractOrgFromSNI) and the config store used to
	// resolve the org and check credentials. The validator deliberately holds
	// no session-provider reference: it authenticates only and records no
	// username→org routing state.
	cp *ControlPlane
}

// ValidateCredentials is the SNI-less entry point. It forwards to
// ValidateCredentialsForSNI with an empty server name and returns that result.
func (v *cpFlightCredentialValidator) ValidateCredentials(username, password string) bool {
	const noSNI = ""
	return v.ValidateCredentialsForSNI(noSNI, username, password)
}

// ValidateCredentialsForSNI checks (username, password) against the org that
// the connection's managed hostname (SNI) resolves to. It returns false when
// the SNI is not a managed hostname, and otherwise returns the result of
// authForSNIPrefix.
//
// Nothing is recorded per username here: session routing re-derives the org
// from the same SNI at create time (orgRoutedSessionProvider.resolveOrg), so
// the authenticated principal stays tied to this connection's hostname rather
// than to a shared username→org map that two tenants could collide on.
func (v *cpFlightCredentialValidator) ValidateCredentialsForSNI(sni, username, password string) bool {
	prefix, managed := v.cp.extractOrgFromSNI(sni)
	if managed {
		return v.authForSNIPrefix(sni, prefix, username, password)
	}

	// A username alone can collide across orgs, so identity requires a
	// managed hostname — there is no username-scan fallback.
	slog.Warn("Flight auth rejected: SNI does not match a managed hostname.",
		"sni", sni, "expected", v.cp.managedHostnameHint(), "user", username)
	return false
}

// authForSNIPrefix validates (username, password) against the single org the
// SNI-derived hostname prefix resolves to (via hostname_alias, database_name,
// or DNS-safe org name — see ConfigStore.ResolveSNIPrefix).
func (v *cpFlightCredentialValidator) authForSNIPrefix(sni, sniPrefix, username, password string) bool {
cp := v.cp
orgID, dbname := cp.configStore.ResolveSNIPrefix(sniPrefix)
if orgID == "" {
slog.Warn("Flight client SNI references unknown org.",
"sni", sni, "sni_prefix", sniPrefix, "user", username)
return false
}
observeSNIRoutingResolution("flight", dbname != sniPrefix)
if !cp.configStore.ValidateOrgUser(orgID, username, password) {
return false
return cp.configStore.ValidateOrgUser(orgID, username, password)
}

// flightOrgFromContext extracts the TLS server name (SNI) from the request
// context and maps it to an org via resolveFlightOrgFromSNI. It is wired in as
// orgRoutedSessionProvider.resolveOrg so each session is bound to the org of
// its own connection, mirroring the resolution done at auth time.
func (cp *ControlPlane) flightOrgFromContext(ctx context.Context) (string, bool) {
	sni := flightsqlingress.SNIFromContext(ctx)
	return cp.resolveFlightOrgFromSNI(sni)
}

// resolveFlightOrgFromSNI maps a TLS ServerName to its org, returning ok=false
// for unmanaged hostnames or prefixes that resolve to no org. It is a pure
// lookup: nothing is cached or recorded per username.
func (cp *ControlPlane) resolveFlightOrgFromSNI(sni string) (orgID string, ok bool) {
	prefix, isManaged := cp.extractOrgFromSNI(sni)
	if !isManaged {
		return "", false
	}
	// The second result (the database name) is not needed for routing.
	orgID, _ = cp.configStore.ResolveSNIPrefix(prefix)
	return orgID, orgID != ""
}

func (cp *ControlPlane) startFlightIngress() {
Expand All @@ -1820,18 +1829,16 @@ func (cp *ControlPlane) startFlightIngress() {

switch {
case cp.configStore != nil && cp.orgRouter != nil:
// Multi-tenant: auth via config store, sessions routed per-org.
// When the client connected via a managed hostname, the SNI is
// authoritative for org routing; otherwise we fall back to scanning
// orgs by (username, password) and log a warning so legacy callers
// can be migrated.
// Multi-tenant: auth via config store, sessions routed per-org. The
// managed hostname (SNI) is authoritative for org identity at both auth
// and session-create time; there is no username-keyed routing state.
orgProvider := &orgRoutedSessionProvider{
orgRouter: cp.orgRouter,
configStore: cp.configStore,
pidSession: make(map[int32]flightOwnedSession),
userOrg: make(map[string]string),
resolveOrg: cp.flightOrgFromContext,
}
validator = &cpFlightCredentialValidator{cp: cp, orgProvider: orgProvider}
validator = &cpFlightCredentialValidator{cp: cp}
provider = orgProvider
case cp.sessions != nil:
// Single-tenant: static users map, single session manager.
Expand Down
29 changes: 22 additions & 7 deletions controlplane/flight_ingress.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,25 +65,40 @@ type flightOwnedSession struct {
}

// orgRoutedSessionProvider routes Flight SQL session operations to the correct
// org's SessionManager. The org is derived from the connection's managed
// hostname (SNI) — the same immutable per-connection identity that auth uses —
// re-resolved at session-create time via resolveOrg. There is deliberately NO
// username→org map: a username is only unique within an org, so a shared map
// keyed by username collides when two tenants share a username (the auth result
// for one connection could be overwritten by a concurrent connection's).
type orgRoutedSessionProvider struct {
	orgRouter   OrgRouterInterface
	configStore ConfigStoreInterface
	// resolveOrg resolves the org for a session from the request context's SNI.
	// Injected so it can be stubbed in tests; production wires it to
	// ControlPlane.flightOrgFromContext.
	resolveOrg func(ctx context.Context) (orgID string, ok bool)

	mu         sync.RWMutex
	pidSession map[int32]flightOwnedSession // pid → owning session manager
}

func (p *orgRoutedSessionProvider) CreateSession(ctx context.Context, username string, pid int32, memoryLimit string, threads int) (int32, *flightclient.FlightExecutor, error) {
p.mu.RLock()
orgID := p.userOrg[username]
p.mu.RUnlock()
// Bind the session to the org of THIS connection's managed hostname, not a
// shared username lookup. Fail closed if the SNI no longer resolves an org.
if p.resolveOrg == nil {
return 0, nil, fmt.Errorf("flight session provider misconfigured: no org resolver")
}
orgID, ok := p.resolveOrg(ctx)
if !ok || orgID == "" {
slog.Warn("Flight SQL session: could not resolve org from connection SNI.", "username", username)
return 0, nil, fmt.Errorf("could not resolve organization for flight session")
}

_, sessions, _, ok := p.orgRouter.StackForOrg(orgID)
if !ok {
slog.Warn("Flight SQL session: no org stack for user.", "username", username, "org", orgID)
return 0, nil, fmt.Errorf("no org configured for user %q", username)
slog.Warn("Flight SQL session: no org stack for org.", "username", username, "org", orgID)
return 0, nil, fmt.Errorf("no org stack for org %q", orgID)
}

// SessionManager.resolveSessionLimits handles rebalancer defaults,
Expand Down
77 changes: 73 additions & 4 deletions controlplane/flight_ingress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,86 @@ func (r *reconnectTestOrgRouter) IsMigratingForOrg(_ string) bool { return false
func (r *reconnectTestOrgRouter) SetWarmCapacityTarget(_ int) {}
func (r *reconnectTestOrgRouter) ShutdownAll() {}

// recordingOrgRouter is a test double that remembers, in call order, every
// orgID passed to StackForOrg. It never returns a live stack, so a caller such
// as CreateSession stops right after the routing decision has been recorded.
type recordingOrgRouter struct {
	mu    sync.Mutex
	calls []string // orgIDs seen by StackForOrg, oldest first
}

// StackForOrg appends orgID to the call log and reports that no stack exists.
func (r *recordingOrgRouter) StackForOrg(orgID string) (WorkerPool, *SessionManager, *MemoryRebalancer, bool) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.calls = append(r.calls, orgID)
	return nil, nil, nil, false
}

// IcebergConfigForOrg always reports that no Iceberg config is available.
func (r *recordingOrgRouter) IcebergConfigForOrg(_ string) (server.IcebergConfig, bool) {
	var none server.IcebergConfig
	return none, false
}

// IsMigratingForOrg always reports false.
func (r *recordingOrgRouter) IsMigratingForOrg(_ string) bool {
	return false
}

// SetWarmCapacityTarget is a no-op.
func (r *recordingOrgRouter) SetWarmCapacityTarget(_ int) {
}

// ShutdownAll is a no-op.
func (r *recordingOrgRouter) ShutdownAll() {
}

// testFlightOrgKey is the context key under which tests store an org ID for a
// stubbed resolveOrg to read back.
type testFlightOrgKey struct{}

// TestOrgRoutedSessionProviderRoutesByContextSNINotUsername guards against the
// username collision: two connections that both log in as "alice" but arrive
// with different org identities in their context must each be routed to their
// own org. The org comes from the per-connection context (standing in for the
// SNI), never from a shared username→org map.
func TestOrgRoutedSessionProviderRoutesByContextSNINotUsername(t *testing.T) {
	router := &recordingOrgRouter{}
	// Stub resolver: the org is whatever the test placed in the context.
	orgFromCtx := func(ctx context.Context) (string, bool) {
		org, _ := ctx.Value(testFlightOrgKey{}).(string)
		return org, org != ""
	}
	provider := &orgRoutedSessionProvider{
		orgRouter:  router,
		pidSession: make(map[int32]flightOwnedSession),
		resolveOrg: orgFromCtx,
	}

	// Same username on both connections; only the context's org differs.
	for _, org := range []string{"org-a", "org-b"} {
		ctx := context.WithValue(context.Background(), testFlightOrgKey{}, org)
		_, _, err := provider.CreateSession(ctx, "alice", 0, "", 0)
		if err == nil {
			t.Fatal("expected failure (no live stack)")
		}
	}

	router.mu.Lock()
	defer router.mu.Unlock()
	routedInOrder := len(router.calls) == 2 && router.calls[0] == "org-a" && router.calls[1] == "org-b"
	if !routedInOrder {
		t.Fatalf("expected StackForOrg(org-a) then StackForOrg(org-b); got %v", router.calls)
	}
}

// TestOrgRoutedSessionProviderFailsClosedWhenSNIUnresolved checks the
// fail-closed path: when the resolver cannot map the connection to an org,
// CreateSession must return an error and must not consult the org router.
func TestOrgRoutedSessionProviderFailsClosedWhenSNIUnresolved(t *testing.T) {
	router := &recordingOrgRouter{}
	neverResolves := func(_ context.Context) (string, bool) {
		return "", false
	}
	provider := &orgRoutedSessionProvider{
		orgRouter:  router,
		pidSession: make(map[int32]flightOwnedSession),
		resolveOrg: neverResolves,
	}

	_, _, err := provider.CreateSession(context.Background(), "alice", 0, "", 0)
	if err == nil {
		t.Fatal("expected CreateSession to fail closed when org can't be resolved")
	}

	router.mu.Lock()
	defer router.mu.Unlock()
	if n := len(router.calls); n != 0 {
		t.Fatalf("expected no StackForOrg call when SNI unresolved; got %v", router.calls)
	}
}

func TestOrgRoutedSessionProviderReconnectSessionUsesDurableOrgID(t *testing.T) {
router := &reconnectTestOrgRouter{
orgID: "analytics",
}
provider := &orgRoutedSessionProvider{
orgRouter: router,
pidSession: make(map[int32]flightOwnedSession),
userOrg: make(map[string]string),
configStore: nil,
}

Expand Down Expand Up @@ -64,7 +136,6 @@ func TestOrgRoutedSessionProviderDestroySessionRemovesPid(t *testing.T) {
provider := &orgRoutedSessionProvider{
orgRouter: &mockOrgRouter{sessions: sm, ok: true},
pidSession: map[int32]flightOwnedSession{42: {orgID: "test", sessions: sm}},
userOrg: make(map[string]string),
}

// Destroy known pid — should remove from map.
Expand All @@ -84,7 +155,6 @@ func TestOrgRoutedSessionProviderDestroyUnknownPidNoOp(t *testing.T) {
provider := &orgRoutedSessionProvider{
orgRouter: &mockOrgRouter{ok: true},
pidSession: make(map[int32]flightOwnedSession),
userOrg: make(map[string]string),
}

// Should not panic.
Expand All @@ -97,7 +167,6 @@ func TestOrgRoutedSessionProviderConcurrentDestroys(t *testing.T) {
provider := &orgRoutedSessionProvider{
orgRouter: &mockOrgRouter{sessions: sm, ok: true},
pidSession: make(map[int32]flightOwnedSession),
userOrg: make(map[string]string),
}

// Pre-populate
Expand Down
52 changes: 45 additions & 7 deletions controlplane/sni_kubernetes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (
"github.com/jackc/pgx/v5/pgconn"
"github.com/posthog/duckgres/controlplane/configstore"
"github.com/posthog/duckgres/server"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/peer"
)

func TestExtractOrgFromSNI(t *testing.T) {
Expand Down Expand Up @@ -182,10 +184,7 @@ func newFlightValidator(t *testing.T, mode string, store *fakeConfigStore) *cpFl
},
configStore: store,
}
provider := &orgRoutedSessionProvider{
userOrg: make(map[string]string),
}
return &cpFlightCredentialValidator{cp: cp, orgProvider: provider}
return &cpFlightCredentialValidator{cp: cp}
}

func newSNIControlPlane(store *fakeConfigStore) *ControlPlane {
Expand Down Expand Up @@ -333,7 +332,9 @@ func testControlPlaneTLSConfig(t *testing.T) *tls.Config {
// orgs), so a non-managed hostname always fails.

// TestFlightValidatorMatchedSNI: SNI matches, so we resolve via ResolveSNIPrefix
// and validate against that single org.
// and validate against that single org. The validator only authenticates — it
// stores no username→org routing state (session routing re-derives the org from
// the connection SNI; see flight_ingress_test.go).
func TestFlightValidatorMatchedSNI(t *testing.T) {
store := &fakeConfigStore{
resolveSNIPrefix: func(prefix string) (string, string) {
Expand All @@ -355,8 +356,45 @@ func TestFlightValidatorMatchedSNI(t *testing.T) {
t.Fatalf("expected one ResolveSNIPrefix + one ValidateOrgUser; got %d / %d",
store.resolveSNIPrefixCalls, store.validateOrgUserCalls)
}
if got := v.orgProvider.userOrg["alice"]; got != "org-acme" {
t.Fatalf("expected userOrg['alice'] = org-acme; got %q", got)
}

// flightCtxWithSNI builds a gRPC context carrying a TLS ServerName, exactly as
// the real Flight ingress sees it, so we exercise the real SNIFromContext →
// extractOrgFromSNI → ResolveSNIPrefix chain that routes a session to its org.
func flightCtxWithSNI(sni string) context.Context {
return peer.NewContext(context.Background(), &peer.Peer{
AuthInfo: credentials.TLSInfo{State: tls.ConnectionState{ServerName: sni}},
})
}

// TestFlightOrgFromContextResolvesViaSNI exercises flightOrgFromContext end to
// end: a managed hostname with a known prefix resolves to its org, while an
// unknown prefix, an unmanaged hostname, and a context with no peer at all
// must each fail closed.
func TestFlightOrgFromContextResolvesViaSNI(t *testing.T) {
	// Only the "acme" prefix is known to the fake store.
	resolve := func(prefix string) (string, string) {
		switch prefix {
		case "acme":
			return "org-acme", "acme_db"
		default:
			return "", ""
		}
	}
	cp := &ControlPlane{
		cfg:         ControlPlaneConfig{ManagedHostnameSuffixes: []string{".dw.us.postwh.com"}},
		configStore: &fakeConfigStore{resolveSNIPrefix: resolve},
	}

	org, ok := cp.flightOrgFromContext(flightCtxWithSNI("acme.dw.us.postwh.com"))
	if !(ok && org == "org-acme") {
		t.Fatalf("managed SNI should resolve org-acme; got (%q, %v)", org, ok)
	}

	org, ok = cp.flightOrgFromContext(flightCtxWithSNI("ghost.dw.us.postwh.com"))
	if ok || org != "" {
		t.Fatalf("unknown managed prefix must fail closed; got (%q, %v)", org, ok)
	}

	_, ok = cp.flightOrgFromContext(flightCtxWithSNI("evil.example.com"))
	if ok {
		t.Fatalf("unmanaged hostname must fail closed")
	}

	_, ok = cp.flightOrgFromContext(context.Background())
	if ok {
		t.Fatalf("missing peer/SNI must fail closed")
	}
}

Expand Down
10 changes: 6 additions & 4 deletions server/flightsqlingress/ingress.go
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ func (h *ControlPlaneFlightSQLHandler) authenticateBasicCredentials(ctx context.
// CredentialValidator path.
ok := false
if sniAware, isSNIAware := h.validator.(SNIAwareCredentialValidator); isSNIAware {
ok = sniAware.ValidateCredentialsForSNI(sniFromContext(ctx), username, password)
ok = sniAware.ValidateCredentialsForSNI(SNIFromContext(ctx), username, password)
} else {
ok = h.validator.ValidateCredentials(username, password)
}
Expand All @@ -449,9 +449,11 @@ func (h *ControlPlaneFlightSQLHandler) authenticateBasicCredentials(ctx context.
return username, nil
}

// sniFromContext returns the TLS ServerName the client sent, or "" if the
// connection isn't TLS-terminated by this server (e.g. in unit tests).
func sniFromContext(ctx context.Context) string {
// SNIFromContext returns the TLS ServerName (SNI) the client sent, or "" if the
// connection isn't TLS-terminated by this server (e.g. in unit tests). Exported
// so callers that route by org reuse the exact same extraction the auth path
// uses — auth and routing must never disagree on a connection's hostname.
func SNIFromContext(ctx context.Context) string {
pr, ok := peer.FromContext(ctx)
if !ok || pr == nil {
return ""
Expand Down
Loading