diff --git a/controlplane/control.go b/controlplane/control.go index 7ea7cd8b..8ec34cb6 100644 --- a/controlplane/control.go +++ b/controlplane/control.go @@ -1768,32 +1768,29 @@ func (cp *ControlPlane) drainAfterUpgrade() { // within that org. Flight has no `database` param, so there is no catalog // selection here — the per-user default catalog applies. type cpFlightCredentialValidator struct { - cp *ControlPlane - orgProvider *orgRoutedSessionProvider + cp *ControlPlane } func (v *cpFlightCredentialValidator) ValidateCredentials(username, password string) bool { return v.ValidateCredentialsForSNI("", username, password) } +// ValidateCredentialsForSNI authenticates (username, password) against the org +// the connection's managed hostname (SNI) resolves to. It does NOT stash the +// resolved org anywhere keyed by username — session routing re-derives the org +// from the same SNI at create time (orgRoutedSessionProvider.resolveOrg), so the +// authenticated principal stays bound to this connection's hostname rather than +// a shared username→org map that two tenants could collide on. func (v *cpFlightCredentialValidator) ValidateCredentialsForSNI(sni, username, password string) bool { cp := v.cp sniPrefix, isManaged := cp.extractOrgFromSNI(sni) if !isManaged { - // A username alone can collide across orgs, so identity now requires a + // A username alone can collide across orgs, so identity requires a // managed hostname — there is no username-scan fallback. slog.Warn("Flight auth rejected: SNI does not match a managed hostname.", "sni", sni, "expected", cp.managedHostnameHint(), "user", username) return false } - return v.authForSNIPrefix(sni, sniPrefix, username, password) -} - -// authForSNIPrefix validates (username, password) against the single org the -// SNI-derived hostname prefix resolves to (via hostname_alias, database_name, -// or DNS-safe org name — see ConfigStore.ResolveSNIPrefix). -func (v *cpFlightCredentialValidator) authForSNIPrefix(sni, sniPrefix, username, password string) bool { - cp := v.cp orgID, dbname := cp.configStore.ResolveSNIPrefix(sniPrefix) if orgID == "" { slog.Warn("Flight client SNI references unknown org.", @@ -1801,13 +1798,25 @@ func (v *cpFlightCredentialValidator) authForSNIPrefix(sni, sniPrefix, username, return false } observeSNIRoutingResolution("flight", dbname != sniPrefix) - if !cp.configStore.ValidateOrgUser(orgID, username, password) { - return false + return cp.configStore.ValidateOrgUser(orgID, username, password) +} + +// flightOrgFromContext resolves the org for a Flight session from the request +// context's SNI (the managed hostname). Used by orgRoutedSessionProvider to bind +// each session to its connection's org, mirroring the auth-time resolution. +func (cp *ControlPlane) flightOrgFromContext(ctx context.Context) (string, bool) { + return cp.resolveFlightOrgFromSNI(flightsqlingress.SNIFromContext(ctx)) +} + +// resolveFlightOrgFromSNI maps a TLS ServerName to its org, returning ok=false +// for unmanaged hostnames or prefixes that resolve to no org. +func (cp *ControlPlane) resolveFlightOrgFromSNI(sni string) (orgID string, ok bool) { + prefix, isManaged := cp.extractOrgFromSNI(sni) + if !isManaged { + return "", false } - v.orgProvider.mu.Lock() - v.orgProvider.userOrg[username] = orgID - v.orgProvider.mu.Unlock() - return true + orgID, _ = cp.configStore.ResolveSNIPrefix(prefix) + return orgID, orgID != "" } func (cp *ControlPlane) startFlightIngress() { @@ -1820,18 +1829,16 @@ func (cp *ControlPlane) startFlightIngress() { switch { case cp.configStore != nil && cp.orgRouter != nil: - // Multi-tenant: auth via config store, sessions routed per-org. - // When the client connected via a managed hostname, the SNI is - // authoritative for org routing; otherwise we fall back to scanning - // orgs by (username, password) and log a warning so legacy callers - // can be migrated. + // Multi-tenant: auth via config store, sessions routed per-org. The + // managed hostname (SNI) is authoritative for org identity at both auth + // and session-create time; there is no username-keyed routing state. orgProvider := &orgRoutedSessionProvider{ orgRouter: cp.orgRouter, configStore: cp.configStore, pidSession: make(map[int32]flightOwnedSession), - userOrg: make(map[string]string), + resolveOrg: cp.flightOrgFromContext, } - validator = &cpFlightCredentialValidator{cp: cp, orgProvider: orgProvider} + validator = &cpFlightCredentialValidator{cp: cp} provider = orgProvider case cp.sessions != nil: // Single-tenant: static users map, single session manager. diff --git a/controlplane/flight_ingress.go b/controlplane/flight_ingress.go index ea51f186..98ca9984 100644 --- a/controlplane/flight_ingress.go +++ b/controlplane/flight_ingress.go @@ -65,25 +65,40 @@ type flightOwnedSession struct { } // orgRoutedSessionProvider routes Flight SQL session operations to the correct -// org's SessionManager based on the username→org mapping resolved during auth. +// org's SessionManager. The org is derived from the connection's managed +// hostname (SNI) — the same immutable per-connection identity that auth uses — +// re-resolved at session-create time via resolveOrg. There is deliberately NO +// username→org map: a username is only unique within an org, so a shared map +// keyed by username collides when two tenants share a username (the auth result +// for one connection could be overwritten by a concurrent connection's). type orgRoutedSessionProvider struct { orgRouter OrgRouterInterface configStore ConfigStoreInterface + // resolveOrg resolves the org for a session from the request context's SNI. + // Injected so it can be stubbed in tests; production wires it to + // ControlPlane.flightOrgFromContext. + resolveOrg func(ctx context.Context) (orgID string, ok bool) mu sync.RWMutex pidSession map[int32]flightOwnedSession // pid → owning session manager - userOrg map[string]string // username → orgID (populated during auth) } func (p *orgRoutedSessionProvider) CreateSession(ctx context.Context, username string, pid int32, memoryLimit string, threads int) (int32, *flightclient.FlightExecutor, error) { - p.mu.RLock() - orgID := p.userOrg[username] - p.mu.RUnlock() + // Bind the session to the org of THIS connection's managed hostname, not a + // shared username lookup. Fail closed if the SNI no longer resolves an org. + if p.resolveOrg == nil { + return 0, nil, fmt.Errorf("flight session provider misconfigured: no org resolver") + } + orgID, ok := p.resolveOrg(ctx) + if !ok || orgID == "" { + slog.Warn("Flight SQL session: could not resolve org from connection SNI.", "username", username) + return 0, nil, fmt.Errorf("could not resolve organization for flight session") + } _, sessions, _, ok := p.orgRouter.StackForOrg(orgID) if !ok { - slog.Warn("Flight SQL session: no org stack for user.", "username", username, "org", orgID) - return 0, nil, fmt.Errorf("no org configured for user %q", username) + slog.Warn("Flight SQL session: no org stack for org.", "username", username, "org", orgID) + return 0, nil, fmt.Errorf("no org stack for org %q", orgID) } // SessionManager.resolveSessionLimits handles rebalancer defaults, diff --git a/controlplane/flight_ingress_test.go b/controlplane/flight_ingress_test.go index 7ba7ecf2..5221aaca 100644 --- a/controlplane/flight_ingress_test.go +++ b/controlplane/flight_ingress_test.go @@ -27,6 +27,79 @@ func (r *reconnectTestOrgRouter) IsMigratingForOrg(_ string) bool { return false func (r *reconnectTestOrgRouter) SetWarmCapacityTarget(_ int) {} func (r *reconnectTestOrgRouter) ShutdownAll() {} +// recordingOrgRouter records the orgIDs StackForOrg is asked for. It returns no +// live stack, so CreateSession returns right after recording the routing org. +type recordingOrgRouter struct { + mu sync.Mutex + calls []string +} + +func (r *recordingOrgRouter) StackForOrg(orgID string) (WorkerPool, *SessionManager, *MemoryRebalancer, bool) { + r.mu.Lock() + r.calls = append(r.calls, orgID) + r.mu.Unlock() + return nil, nil, nil, false +} +func (r *recordingOrgRouter) IcebergConfigForOrg(_ string) (server.IcebergConfig, bool) { + return server.IcebergConfig{}, false +} +func (r *recordingOrgRouter) IsMigratingForOrg(_ string) bool { return false } +func (r *recordingOrgRouter) SetWarmCapacityTarget(_ int) {} +func (r *recordingOrgRouter) ShutdownAll() {} + +type testFlightOrgKey struct{} + +// TestOrgRoutedSessionProviderRoutesByContextSNINotUsername proves the fix for +// the username-collision: two connections sharing the username "alice" but from +// different org hostnames each route to THEIR OWN org, because the org is +// re-derived per-connection from the context (SNI) rather than a shared +// username→org map. +func TestOrgRoutedSessionProviderRoutesByContextSNINotUsername(t *testing.T) { + router := &recordingOrgRouter{} + provider := &orgRoutedSessionProvider{ + orgRouter: router, + pidSession: make(map[int32]flightOwnedSession), + resolveOrg: func(ctx context.Context) (string, bool) { + org, _ := ctx.Value(testFlightOrgKey{}).(string) + return org, org != "" + }, + } + + ctxA := context.WithValue(context.Background(), testFlightOrgKey{}, "org-a") + ctxB := context.WithValue(context.Background(), testFlightOrgKey{}, "org-b") + if _, _, err := provider.CreateSession(ctxA, "alice", 0, "", 0); err == nil { + t.Fatal("expected failure (no live stack)") + } + if _, _, err := provider.CreateSession(ctxB, "alice", 0, "", 0); err == nil { + t.Fatal("expected failure (no live stack)") + } + + router.mu.Lock() + defer router.mu.Unlock() + if len(router.calls) != 2 || router.calls[0] != "org-a" || router.calls[1] != "org-b" { + t.Fatalf("expected StackForOrg(org-a) then StackForOrg(org-b); got %v", router.calls) + } +} + +// TestOrgRoutedSessionProviderFailsClosedWhenSNIUnresolved: if the connection's +// SNI no longer resolves to an org, no session is created (fail closed). +func TestOrgRoutedSessionProviderFailsClosedWhenSNIUnresolved(t *testing.T) { + router := &recordingOrgRouter{} + provider := &orgRoutedSessionProvider{ + orgRouter: router, + pidSession: make(map[int32]flightOwnedSession), + resolveOrg: func(_ context.Context) (string, bool) { return "", false }, + } + if _, _, err := provider.CreateSession(context.Background(), "alice", 0, "", 0); err == nil { + t.Fatal("expected CreateSession to fail closed when org can't be resolved") + } + router.mu.Lock() + defer router.mu.Unlock() + if len(router.calls) != 0 { + t.Fatalf("expected no StackForOrg call when SNI unresolved; got %v", router.calls) + } +} + func TestOrgRoutedSessionProviderReconnectSessionUsesDurableOrgID(t *testing.T) { router := &reconnectTestOrgRouter{ orgID: "analytics", @@ -34,7 +107,6 @@ func TestOrgRoutedSessionProviderReconnectSessionUsesDurableOrgID(t *testing.T) provider := &orgRoutedSessionProvider{ orgRouter: router, pidSession: make(map[int32]flightOwnedSession), - userOrg: make(map[string]string), configStore: nil, } @@ -64,7 +136,6 @@ func TestOrgRoutedSessionProviderDestroySessionRemovesPid(t *testing.T) { provider := &orgRoutedSessionProvider{ orgRouter: &mockOrgRouter{sessions: sm, ok: true}, pidSession: map[int32]flightOwnedSession{42: {orgID: "test", sessions: sm}}, - userOrg: make(map[string]string), } // Destroy known pid — should remove from map. @@ -84,7 +155,6 @@ func TestOrgRoutedSessionProviderDestroyUnknownPidNoOp(t *testing.T) { provider := &orgRoutedSessionProvider{ orgRouter: &mockOrgRouter{ok: true}, pidSession: make(map[int32]flightOwnedSession), - userOrg: make(map[string]string), } // Should not panic. @@ -97,7 +167,6 @@ func TestOrgRoutedSessionProviderConcurrentDestroys(t *testing.T) { provider := &orgRoutedSessionProvider{ orgRouter: &mockOrgRouter{sessions: sm, ok: true}, pidSession: make(map[int32]flightOwnedSession), - userOrg: make(map[string]string), } // Pre-populate diff --git a/controlplane/sni_kubernetes_test.go b/controlplane/sni_kubernetes_test.go index edf2bf3c..4725072d 100644 --- a/controlplane/sni_kubernetes_test.go +++ b/controlplane/sni_kubernetes_test.go @@ -14,6 +14,8 @@ import ( "github.com/jackc/pgx/v5/pgconn" "github.com/posthog/duckgres/controlplane/configstore" "github.com/posthog/duckgres/server" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/peer" ) func TestExtractOrgFromSNI(t *testing.T) { @@ -182,10 +184,7 @@ func newFlightValidator(t *testing.T, mode string, store *fakeConfigStore) *cpFl }, configStore: store, } - provider := &orgRoutedSessionProvider{ - userOrg: make(map[string]string), - } - return &cpFlightCredentialValidator{cp: cp, orgProvider: provider} + return &cpFlightCredentialValidator{cp: cp} } func newSNIControlPlane(store *fakeConfigStore) *ControlPlane { @@ -333,7 +332,9 @@ func testControlPlaneTLSConfig(t *testing.T) *tls.Config { // orgs), so a non-managed hostname always fails. // TestFlightValidatorMatchedSNI: SNI matches, so we resolve via ResolveSNIPrefix -// and validate against that single org. +// and validate against that single org. The validator only authenticates — it +// stores no username→org routing state (session routing re-derives the org from +// the connection SNI; see flight_ingress_test.go). func TestFlightValidatorMatchedSNI(t *testing.T) { store := &fakeConfigStore{ resolveSNIPrefix: func(prefix string) (string, string) { @@ -355,8 +356,45 @@ func TestFlightValidatorMatchedSNI(t *testing.T) { t.Fatalf("expected one ResolveSNIPrefix + one ValidateOrgUser; got %d / %d", store.resolveSNIPrefixCalls, store.validateOrgUserCalls) } - if got := v.orgProvider.userOrg["alice"]; got != "org-acme" { - t.Fatalf("expected userOrg['alice'] = org-acme; got %q", got) +} + +// flightCtxWithSNI builds a gRPC context carrying a TLS ServerName, exactly as +// the real Flight ingress sees it, so we exercise the real SNIFromContext → +// extractOrgFromSNI → ResolveSNIPrefix chain that routes a session to its org. +func flightCtxWithSNI(sni string) context.Context { + return peer.NewContext(context.Background(), &peer.Peer{ + AuthInfo: credentials.TLSInfo{State: tls.ConnectionState{ServerName: sni}}, + }) +} + +// TestFlightOrgFromContextResolvesViaSNI verifies that session routing derives +// the org from the connection's TLS SNI (the load-bearing path the no-collision +// fix relies on), and fails closed for unmanaged or missing hostnames. +func TestFlightOrgFromContextResolvesViaSNI(t *testing.T) { + store := &fakeConfigStore{ + resolveSNIPrefix: func(prefix string) (string, string) { + if prefix == "acme" { + return "org-acme", "acme_db" + } + return "", "" + }, + } + cp := &ControlPlane{ + cfg: ControlPlaneConfig{ManagedHostnameSuffixes: []string{".dw.us.postwh.com"}}, + configStore: store, + } + + if org, ok := cp.flightOrgFromContext(flightCtxWithSNI("acme.dw.us.postwh.com")); !ok || org != "org-acme" { + t.Fatalf("managed SNI should resolve org-acme; got (%q, %v)", org, ok) + } + if org, ok := cp.flightOrgFromContext(flightCtxWithSNI("ghost.dw.us.postwh.com")); ok || org != "" { + t.Fatalf("unknown managed prefix must fail closed; got (%q, %v)", org, ok) + } + if _, ok := cp.flightOrgFromContext(flightCtxWithSNI("evil.example.com")); ok { + t.Fatalf("unmanaged hostname must fail closed") + } + if _, ok := cp.flightOrgFromContext(context.Background()); ok { + t.Fatalf("missing peer/SNI must fail closed") } } diff --git a/server/flightsqlingress/ingress.go b/server/flightsqlingress/ingress.go index a3169f25..38fda263 100644 --- a/server/flightsqlingress/ingress.go +++ b/server/flightsqlingress/ingress.go @@ -434,7 +434,7 @@ func (h *ControlPlaneFlightSQLHandler) authenticateBasicCredentials(ctx context. // CredentialValidator path. ok := false if sniAware, isSNIAware := h.validator.(SNIAwareCredentialValidator); isSNIAware { - ok = sniAware.ValidateCredentialsForSNI(sniFromContext(ctx), username, password) + ok = sniAware.ValidateCredentialsForSNI(SNIFromContext(ctx), username, password) } else { ok = h.validator.ValidateCredentials(username, password) } @@ -449,9 +449,11 @@ func (h *ControlPlaneFlightSQLHandler) authenticateBasicCredentials(ctx context. return username, nil } -// sniFromContext returns the TLS ServerName the client sent, or "" if the -// connection isn't TLS-terminated by this server (e.g. in unit tests). -func sniFromContext(ctx context.Context) string { +// SNIFromContext returns the TLS ServerName (SNI) the client sent, or "" if the +// connection isn't TLS-terminated by this server (e.g. in unit tests). Exported +// so callers that route by org reuse the exact same extraction the auth path +// uses — auth and routing must never disagree on a connection's hostname. +func SNIFromContext(ctx context.Context) string { pr, ok := peer.FromContext(ctx) if !ok || pr == nil { return ""