From 83a8f4357eb660f72bb720f899223b17b612fa76 Mon Sep 17 00:00:00 2001 From: simon Date: Thu, 30 Apr 2026 10:29:47 +0200 Subject: [PATCH 01/25] Extract Lakebase target resolver into shared libs/lakebase/target Move the Postgres autoscaling and provisioned target-resolution helpers out of cmd/psql/ into a shared package so a second consumer (the new experimental postgres query command, in a follow-up commit) can reuse the same SDK shapes. cmd/psql keeps its interactive UX by wrapping the shared AutoSelect* helpers with errors.As fallbacks on AmbiguousError. No behavior change for cmd/psql; existing acceptance tests pass. Co-authored-by: Isaac --- cmd/psql/psql.go | 61 +++--------- cmd/psql/psql_autoscaling.go | 121 +++++++---------------- cmd/psql/psql_provisioned.go | 46 +++------ cmd/psql/psql_test.go | 83 ---------------- libs/lakebase/target/autoscaling.go | 122 +++++++++++++++++++++++ libs/lakebase/target/provisioned.go | 64 ++++++++++++ libs/lakebase/target/target.go | 145 ++++++++++++++++++++++++++++ libs/lakebase/target/target_test.go | 136 ++++++++++++++++++++++++++ 8 files changed, 523 insertions(+), 255 deletions(-) delete mode 100644 cmd/psql/psql_test.go create mode 100644 libs/lakebase/target/autoscaling.go create mode 100644 libs/lakebase/target/provisioned.go create mode 100644 libs/lakebase/target/target.go create mode 100644 libs/lakebase/target/target_test.go diff --git a/cmd/psql/psql.go b/cmd/psql/psql.go index e7f3a65f8b3..e5cfaff5cff 100644 --- a/cmd/psql/psql.go +++ b/cmd/psql/psql.go @@ -11,6 +11,7 @@ import ( "github.com/databricks/cli/libs/cmdctx" "github.com/databricks/cli/libs/cmdgroup" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/lakebase/target" libpsql "github.com/databricks/cli/libs/psql" "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/service/database" @@ -86,9 +87,9 @@ For more information, see: https://docs.databricks.com/aws/en/oltp/ if argsLenAtDash < 0 { argsLenAtDash = len(args) } - target := "" + targetArg := "" if argsLenAtDash == 1 { - target = args[0] + targetArg = args[0] } else if argsLenAtDash > 1 { return errors.New("expected at most one positional argument for target") } @@ -109,16 +110,17 @@ For more information, see: https://docs.databricks.com/aws/en/oltp/ } // Positional argument takes precedence - if target != "" { - if strings.HasPrefix(target, "projects/") { + if targetArg != "" { + if target.IsAutoscalingPath(targetArg) { if provisionedFlag { return errors.New("cannot use --provisioned flag with an autoscaling resource path") } - projectID, branchID, endpointID, err := parseResourcePath(target) + spec, err := target.ParseAutoscalingPath(targetArg) if err != nil { return err } + projectID, branchID, endpointID := spec.ProjectID, spec.BranchID, spec.EndpointID // Check for conflicts between path and flags if projectFlag != "" && projectFlag != projectID { @@ -149,7 +151,7 @@ For more information, see: https://docs.databricks.com/aws/en/oltp/ if autoscalingFlag { return errors.New("cannot use --autoscaling flag with a provisioned instance name") } - return connectProvisioned(ctx, target, retryConfig, extraArgs) + return connectProvisioned(ctx, targetArg, retryConfig, extraArgs) } // No positional argument - use flags only @@ -197,45 +199,6 @@ For more information, see: https://docs.databricks.com/aws/en/oltp/ return cmd } -// parseResourcePath extracts project, branch, and endpoint IDs from a resource path. -// Returns an error for malformed paths. 
-func parseResourcePath(input string) (project, branch, endpoint string, err error) { - parts := strings.Split(input, "/") - - // Must start with projects/{project_id} - if len(parts) < 2 || parts[0] != "projects" { - return "", "", "", fmt.Errorf("invalid resource path: %s", input) - } - if parts[1] == "" { - return "", "", "", errors.New("invalid resource path: missing project ID") - } - project = parts[1] - - // Optional: branches/{branch_id} - if len(parts) > 2 { - if len(parts) < 4 || parts[2] != "branches" { - return "", "", "", errors.New("invalid resource path: expected 'branches' after project") - } - if parts[3] == "" { - return "", "", "", errors.New("invalid resource path: missing branch ID") - } - branch = parts[3] - } - - // Optional: endpoints/{endpoint_id} - if len(parts) > 4 { - if len(parts) < 6 || parts[4] != "endpoints" { - return "", "", "", errors.New("invalid resource path: expected 'endpoints' after branch") - } - if parts[5] == "" { - return "", "", "", errors.New("invalid resource path: missing endpoint ID") - } - endpoint = parts[5] - } - - return project, branch, endpoint, nil -} - // listAllDatabases fetches all database instances and projects in parallel. // Errors are silently ignored; callers should check for empty results. func listAllDatabases(ctx context.Context, w *databricks.WorkspaceClient) ([]database.DatabaseInstance, []postgres.Project) { @@ -248,12 +211,12 @@ func listAllDatabases(ctx context.Context, w *databricks.WorkspaceClient) ([]dat projectsCh := make(chan result[postgres.Project], 1) go func() { - instances, err := w.Database.ListDatabaseInstancesAll(ctx, database.ListDatabaseInstancesRequest{}) + instances, err := target.ListProvisionedInstances(ctx, w) instancesCh <- result[database.DatabaseInstance]{instances, err} }() go func() { - projects, err := w.Postgres.ListProjectsAll(ctx, postgres.ListProjectsRequest{}) + projects, err := target.ListProjects(ctx, w) projectsCh <- result[postgres.Project]{projects, err} }() @@ -294,7 +257,7 @@ func showSelectionAndConnect(ctx context.Context, retryConfig libpsql.RetryConfi }) } for _, proj := range projects { - displayName := extractIDFromName(proj.Name, "projects") + displayName := target.ExtractID(proj.Name, target.PathSegmentProjects) if proj.Status != nil && proj.Status.DisplayName != "" { displayName = proj.Status.DisplayName } @@ -315,7 +278,7 @@ func showSelectionAndConnect(ctx context.Context, retryConfig libpsql.RetryConfi } if after, ok := strings.CutPrefix(selected, "autoscaling:"); ok { projectName := after - projectID := extractIDFromName(projectName, "projects") + projectID := target.ExtractID(projectName, target.PathSegmentProjects) return connectAutoscaling(ctx, projectID, "", "", retryConfig, extraArgs) } diff --git a/cmd/psql/psql_autoscaling.go b/cmd/psql/psql_autoscaling.go index 00c555e4c12..4273dad3b50 100644 --- a/cmd/psql/psql_autoscaling.go +++ b/cmd/psql/psql_autoscaling.go @@ -4,10 +4,10 @@ import ( "context" "errors" "fmt" - "strings" "github.com/databricks/cli/libs/cmdctx" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/lakebase/target" libpsql "github.com/databricks/cli/libs/psql" "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/service/postgres" @@ -16,18 +16,6 @@ import ( // autoscalingDefaultDatabase is the default database for Lakebase Autoscaling projects. const autoscalingDefaultDatabase = "databricks_postgres" -// extractIDFromName extracts the ID component from a resource name. 
-// For example, extractIDFromName("projects/foo/branches/bar", "branches") returns "bar". -func extractIDFromName(name, component string) string { - parts := strings.Split(name, "/") - for i := range len(parts) - 1 { - if parts[i] == component { - return parts[i+1] - } - } - return name -} - // connectAutoscaling connects to a Lakebase Autoscaling endpoint. func connectAutoscaling(ctx context.Context, projectID, branchID, endpointID string, retryConfig libpsql.RetryConfig, extraArgs []string) error { w := cmdctx.WorkspaceClient(ctx) @@ -50,11 +38,9 @@ func connectAutoscaling(ctx context.Context, projectID, branchID, endpointID str return errors.New("endpoint host information is not available") } - cred, err := w.Postgres.GenerateDatabaseCredential(ctx, postgres.GenerateDatabaseCredentialRequest{ - Endpoint: endpoint.Name, - }) + token, err := target.AutoscalingCredential(ctx, w, endpoint.Name) if err != nil { - return fmt.Errorf("failed to get database credentials: %w", err) + return err } var endpointType string @@ -83,7 +69,7 @@ func connectAutoscaling(ctx context.Context, projectID, branchID, endpointID str return libpsql.Connect(ctx, libpsql.ConnectOptions{ Host: endpoint.Status.Hosts.Host, Username: user.UserName, - Password: cred.Token, + Password: token, DefaultDatabase: autoscalingDefaultDatabase, ExtraArgs: extraArgs, }, retryConfig) @@ -102,7 +88,7 @@ func resolveEndpoint(ctx context.Context, w *databricks.WorkspaceClient, project } // Get project to display its name - project, err := w.Postgres.GetProject(ctx, postgres.GetProjectRequest{Name: "projects/" + projectID}) + project, err := target.GetProject(ctx, w, projectID) if err != nil { return nil, fmt.Errorf("failed to get project: %w", err) } @@ -136,7 +122,7 @@ func resolveEndpoint(ctx context.Context, w *databricks.WorkspaceClient, project } // Get endpoint to validate and return it - endpoint, err := w.Postgres.GetEndpoint(ctx, postgres.GetEndpointRequest{Name: branch.Name + "/endpoints/" + endpointID}) + endpoint, err := target.GetEndpoint(ctx, w, projectID, branchID, endpointID) if err != nil { return nil, fmt.Errorf("failed to get endpoint: %w", err) } @@ -145,38 +131,31 @@ func resolveEndpoint(ctx context.Context, w *databricks.WorkspaceClient, project return endpoint, nil } +// selectAmbiguous prompts the user to pick one of the choices in an +// AmbiguousError. Caller is expected to have logged a header (e.g. via the +// spinner) before invoking. Used to keep psql's interactive UX while letting +// the shared lib do the actual list+filter work. +func selectAmbiguous(ctx context.Context, amb *target.AmbiguousError, prompt string) (string, error) { + items := make([]cmdio.Tuple, 0, len(amb.Choices)) + for _, c := range amb.Choices { + items = append(items, cmdio.Tuple{Name: c.DisplayName, Id: c.ID}) + } + return cmdio.SelectOrdered(ctx, items, prompt) +} + // selectProjectID auto-selects if there's only one project, otherwise prompts user to select. // Returns the project ID (not the full project object). 
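+// Sketch of the resulting flow: target.AutoSelectProject returns
+// ("", *AmbiguousError) when several projects exist, and selectAmbiguous
+// then renders the interactive picker.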
func selectProjectID(ctx context.Context, w *databricks.WorkspaceClient) (string, error) { sp := cmdio.NewSpinner(ctx) sp.Update("Loading projects...") - projects, err := w.Postgres.ListProjectsAll(ctx, postgres.ListProjectsRequest{}) + id, err := target.AutoSelectProject(ctx, w) sp.Close() - if err != nil { - return "", err - } - - if len(projects) == 0 { - return "", errors.New("no Lakebase Autoscaling projects found in workspace") - } - // Auto-select if there's only one project - if len(projects) == 1 { - return extractIDFromName(projects[0].Name, "projects"), nil - } - - // Multiple projects, prompt user to select - var items []cmdio.Tuple - for _, project := range projects { - projectID := extractIDFromName(project.Name, "projects") - displayName := projectID - if project.Status != nil && project.Status.DisplayName != "" { - displayName = project.Status.DisplayName - } - items = append(items, cmdio.Tuple{Name: displayName, Id: projectID}) + var amb *target.AmbiguousError + if !errors.As(err, &amb) { + return id, err } - - return cmdio.SelectOrdered(ctx, items, "Select project") + return selectAmbiguous(ctx, amb, "Select project") } // selectBranchID auto-selects if there's only one branch, otherwise prompts user to select. @@ -184,31 +163,14 @@ func selectProjectID(ctx context.Context, w *databricks.WorkspaceClient) (string func selectBranchID(ctx context.Context, w *databricks.WorkspaceClient, projectName string) (string, error) { sp := cmdio.NewSpinner(ctx) sp.Update("Loading branches...") - branches, err := w.Postgres.ListBranchesAll(ctx, postgres.ListBranchesRequest{ - Parent: projectName, - }) + id, err := target.AutoSelectBranch(ctx, w, projectName) sp.Close() - if err != nil { - return "", err - } - - if len(branches) == 0 { - return "", errors.New("no branches found in project") - } - - // Auto-select if there's only one branch - if len(branches) == 1 { - return extractIDFromName(branches[0].Name, "branches"), nil - } - // Multiple branches, prompt user to select - var items []cmdio.Tuple - for _, branch := range branches { - branchID := extractIDFromName(branch.Name, "branches") - items = append(items, cmdio.Tuple{Name: branchID, Id: branchID}) + var amb *target.AmbiguousError + if !errors.As(err, &amb) { + return id, err } - - return cmdio.SelectOrdered(ctx, items, "Select branch") + return selectAmbiguous(ctx, amb, "Select branch") } // selectEndpointID auto-selects if there's only one endpoint, otherwise prompts user to select. 
@@ -216,29 +178,12 @@ func selectBranchID(ctx context.Context, w *databricks.WorkspaceClient, projectN func selectEndpointID(ctx context.Context, w *databricks.WorkspaceClient, branchName string) (string, error) { sp := cmdio.NewSpinner(ctx) sp.Update("Loading endpoints...") - endpoints, err := w.Postgres.ListEndpointsAll(ctx, postgres.ListEndpointsRequest{ - Parent: branchName, - }) + id, err := target.AutoSelectEndpoint(ctx, w, branchName) sp.Close() - if err != nil { - return "", err - } - - if len(endpoints) == 0 { - return "", errors.New("no endpoints found in branch") - } - // Auto-select if there's only one endpoint - if len(endpoints) == 1 { - return extractIDFromName(endpoints[0].Name, "endpoints"), nil + var amb *target.AmbiguousError + if !errors.As(err, &amb) { + return id, err } - - // Multiple endpoints, prompt user to select - var items []cmdio.Tuple - for _, endpoint := range endpoints { - endpointID := extractIDFromName(endpoint.Name, "endpoints") - items = append(items, cmdio.Tuple{Name: endpointID, Id: endpointID}) - } - - return cmdio.SelectOrdered(ctx, items, "Select endpoint") + return selectAmbiguous(ctx, amb, "Select endpoint") } diff --git a/cmd/psql/psql_provisioned.go b/cmd/psql/psql_provisioned.go index 88ca1bb9181..9ea88def5ce 100644 --- a/cmd/psql/psql_provisioned.go +++ b/cmd/psql/psql_provisioned.go @@ -7,10 +7,10 @@ import ( "github.com/databricks/cli/libs/cmdctx" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/lakebase/target" libpsql "github.com/databricks/cli/libs/psql" "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/service/database" - "github.com/google/uuid" ) // provisionedDefaultDatabase is the default database for Lakebase Provisioned instances. @@ -39,12 +39,9 @@ func connectProvisioned(ctx context.Context, instanceName string, retryConfig li return errors.New("database instance is not ready for accepting connections") } - cred, err := w.Database.GenerateDatabaseCredential(ctx, database.GenerateDatabaseCredentialRequest{ - InstanceNames: []string{instance.Name}, - RequestId: uuid.NewString(), - }) + token, err := target.ProvisionedCredential(ctx, w, instance.Name) if err != nil { - return fmt.Errorf("failed to get database credentials: %w", err) + return err } cmdio.LogString(ctx, "Connecting to database instance...") @@ -52,7 +49,7 @@ func connectProvisioned(ctx context.Context, instanceName string, retryConfig li return libpsql.Connect(ctx, libpsql.ConnectOptions{ Host: instance.ReadWriteDns, Username: user.UserName, - Password: cred.Token, + Password: token, DefaultDatabase: provisionedDefaultDatabase, ExtraArgs: extraArgs, }, retryConfig) @@ -61,7 +58,6 @@ func connectProvisioned(ctx context.Context, instanceName string, retryConfig li // resolveInstance resolves an instance name to a full instance object. // If instanceName is empty, prompts the user to select one. 
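+// Sketch: resolveInstance(ctx, w, "") defers to selectInstanceID, which wraps
+// target.AutoSelectProvisioned and prompts only when several instances exist.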
func resolveInstance(ctx context.Context, w *databricks.WorkspaceClient, instanceName string) (*database.DatabaseInstance, error) { - // If instance not specified, select one if instanceName == "" { var err error instanceName, err = selectInstanceID(ctx, w) @@ -70,15 +66,9 @@ func resolveInstance(ctx context.Context, w *databricks.WorkspaceClient, instanc } } - instance, err := w.Database.GetDatabaseInstance(ctx, database.GetDatabaseInstanceRequest{ - Name: instanceName, - }) + instance, err := target.GetProvisioned(ctx, w, instanceName) if err != nil { - return nil, fmt.Errorf("failed to get database instance: %w", err) - } - // Ensure Name is set (API response may not include it) - if instance.Name == "" { - instance.Name = instanceName + return nil, err } cmdio.LogString(ctx, "Instance: "+instance.Name) @@ -90,26 +80,12 @@ func resolveInstance(ctx context.Context, w *databricks.WorkspaceClient, instanc func selectInstanceID(ctx context.Context, w *databricks.WorkspaceClient) (string, error) { sp := cmdio.NewSpinner(ctx) sp.Update("Loading instances...") - instances, err := w.Database.ListDatabaseInstancesAll(ctx, database.ListDatabaseInstancesRequest{}) + id, err := target.AutoSelectProvisioned(ctx, w) sp.Close() - if err != nil { - return "", err - } - if len(instances) == 0 { - return "", errors.New("no Lakebase Provisioned instances found in workspace") + var amb *target.AmbiguousError + if !errors.As(err, &amb) { + return id, err } - - // Auto-select if there's only one instance - if len(instances) == 1 { - return instances[0].Name, nil - } - - // Multiple instances, prompt user to select - var items []cmdio.Tuple - for _, inst := range instances { - items = append(items, cmdio.Tuple{Name: inst.Name, Id: inst.Name}) - } - - return cmdio.SelectOrdered(ctx, items, "Select instance") + return selectAmbiguous(ctx, amb, "Select instance") } diff --git a/cmd/psql/psql_test.go b/cmd/psql/psql_test.go deleted file mode 100644 index fc8a7e53cba..00000000000 --- a/cmd/psql/psql_test.go +++ /dev/null @@ -1,83 +0,0 @@ -package psql - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestParseResourcePath(t *testing.T) { - tests := []struct { - name string - input string - project string - branch string - endpoint string - wantErr string - }{ - { - name: "project only", - input: "projects/my-project", - project: "my-project", - }, - { - name: "project and branch", - input: "projects/my-project/branches/main", - project: "my-project", - branch: "main", - }, - { - name: "full path", - input: "projects/my-project/branches/main/endpoints/primary", - project: "my-project", - branch: "main", - endpoint: "primary", - }, - { - name: "missing project ID", - input: "projects/", - wantErr: "missing project ID", - }, - { - name: "missing branch ID", - input: "projects/my-project/branches/", - wantErr: "missing branch ID", - }, - { - name: "missing endpoint ID", - input: "projects/my-project/branches/main/endpoints/", - wantErr: "missing endpoint ID", - }, - { - name: "invalid segment after project", - input: "projects/my-project/invalid/foo", - wantErr: "expected 'branches' after project", - }, - { - name: "invalid segment after branch", - input: "projects/my-project/branches/main/invalid/foo", - wantErr: "expected 'endpoints' after branch", - }, - { - name: "not a projects path", - input: "something/else", - wantErr: "invalid resource path", - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - project, branch, 
endpoint, err := parseResourcePath(tc.input) - if tc.wantErr != "" { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.wantErr) - return - } - require.NoError(t, err) - assert.Equal(t, tc.project, project) - assert.Equal(t, tc.branch, branch) - assert.Equal(t, tc.endpoint, endpoint) - }) - } -} diff --git a/libs/lakebase/target/autoscaling.go b/libs/lakebase/target/autoscaling.go new file mode 100644 index 00000000000..f1edef216d4 --- /dev/null +++ b/libs/lakebase/target/autoscaling.go @@ -0,0 +1,122 @@ +package target + +import ( + "context" + "errors" + "fmt" + + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/service/postgres" +) + +// ListProjects returns all autoscaling projects in the workspace. +func ListProjects(ctx context.Context, w *databricks.WorkspaceClient) ([]postgres.Project, error) { + return w.Postgres.ListProjectsAll(ctx, postgres.ListProjectsRequest{}) +} + +// ListBranches returns all branches under the given project. +// projectName is the SDK resource name like "projects/foo". +func ListBranches(ctx context.Context, w *databricks.WorkspaceClient, projectName string) ([]postgres.Branch, error) { + return w.Postgres.ListBranchesAll(ctx, postgres.ListBranchesRequest{Parent: projectName}) +} + +// ListEndpoints returns all endpoints under the given branch. +// branchName is the SDK resource name like "projects/foo/branches/bar". +func ListEndpoints(ctx context.Context, w *databricks.WorkspaceClient, branchName string) ([]postgres.Endpoint, error) { + return w.Postgres.ListEndpointsAll(ctx, postgres.ListEndpointsRequest{Parent: branchName}) +} + +// GetProject fetches a single project by ID. +func GetProject(ctx context.Context, w *databricks.WorkspaceClient, projectID string) (*postgres.Project, error) { + return w.Postgres.GetProject(ctx, postgres.GetProjectRequest{Name: PathSegmentProjects + "/" + projectID}) +} + +// GetEndpoint fetches a single endpoint by ID, given its parent IDs. +func GetEndpoint(ctx context.Context, w *databricks.WorkspaceClient, projectID, branchID, endpointID string) (*postgres.Endpoint, error) { + name := fmt.Sprintf("projects/%s/branches/%s/endpoints/%s", projectID, branchID, endpointID) + return w.Postgres.GetEndpoint(ctx, postgres.GetEndpointRequest{Name: name}) +} + +// AutoSelectProject returns the only project in the workspace, or an +// AmbiguousError carrying the choices if there are multiple. Returns a plain +// error if there are no projects. +func AutoSelectProject(ctx context.Context, w *databricks.WorkspaceClient) (string, error) { + projects, err := ListProjects(ctx, w) + if err != nil { + return "", err + } + if len(projects) == 0 { + return "", errors.New("no Lakebase Autoscaling projects found in workspace") + } + if len(projects) == 1 { + return ExtractID(projects[0].Name, PathSegmentProjects), nil + } + + choices := make([]Choice, 0, len(projects)) + for _, p := range projects { + id := ExtractID(p.Name, PathSegmentProjects) + display := id + if p.Status != nil && p.Status.DisplayName != "" { + display = p.Status.DisplayName + } + choices = append(choices, Choice{ID: id, DisplayName: display}) + } + return "", &AmbiguousError{Kind: "project", FlagHint: "--project", Choices: choices} +} + +// AutoSelectBranch returns the only branch under projectName, or an +// AmbiguousError if there are multiple. 
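+//
+// Interactive callers typically unwrap the ambiguity; a minimal sketch
+// (the project name is illustrative):
+//
+//	id, err := AutoSelectBranch(ctx, w, "projects/foo")
+//	var amb *AmbiguousError
+//	if errors.As(err, &amb) {
+//		// prompt the user with amb.Choices (cf. cmd/psql's selectAmbiguous)
+//	}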
+func AutoSelectBranch(ctx context.Context, w *databricks.WorkspaceClient, projectName string) (string, error) { + branches, err := ListBranches(ctx, w, projectName) + if err != nil { + return "", err + } + if len(branches) == 0 { + return "", errors.New("no branches found in project") + } + if len(branches) == 1 { + return ExtractID(branches[0].Name, pathSegmentBranches), nil + } + + choices := make([]Choice, 0, len(branches)) + for _, b := range branches { + id := ExtractID(b.Name, pathSegmentBranches) + choices = append(choices, Choice{ID: id, DisplayName: id}) + } + return "", &AmbiguousError{Kind: "branch", Parent: projectName, FlagHint: "--branch", Choices: choices} +} + +// AutoSelectEndpoint returns the only endpoint under branchName, or an +// AmbiguousError if there are multiple. +func AutoSelectEndpoint(ctx context.Context, w *databricks.WorkspaceClient, branchName string) (string, error) { + endpoints, err := ListEndpoints(ctx, w, branchName) + if err != nil { + return "", err + } + if len(endpoints) == 0 { + return "", errors.New("no endpoints found in branch") + } + if len(endpoints) == 1 { + return ExtractID(endpoints[0].Name, pathSegmentEndpoints), nil + } + + choices := make([]Choice, 0, len(endpoints)) + for _, e := range endpoints { + id := ExtractID(e.Name, pathSegmentEndpoints) + choices = append(choices, Choice{ID: id, DisplayName: id}) + } + return "", &AmbiguousError{Kind: "endpoint", Parent: branchName, FlagHint: "--endpoint", Choices: choices} +} + +// AutoscalingCredential issues a short-lived OAuth token that can be used to +// authenticate to the given autoscaling endpoint. endpointName is the SDK +// resource name (e.g. "projects/foo/branches/bar/endpoints/baz"). +func AutoscalingCredential(ctx context.Context, w *databricks.WorkspaceClient, endpointName string) (string, error) { + cred, err := w.Postgres.GenerateDatabaseCredential(ctx, postgres.GenerateDatabaseCredentialRequest{ + Endpoint: endpointName, + }) + if err != nil { + return "", fmt.Errorf("failed to get database credentials: %w", err) + } + return cred.Token, nil +} diff --git a/libs/lakebase/target/provisioned.go b/libs/lakebase/target/provisioned.go new file mode 100644 index 00000000000..773cc867ce0 --- /dev/null +++ b/libs/lakebase/target/provisioned.go @@ -0,0 +1,64 @@ +package target + +import ( + "context" + "errors" + "fmt" + + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/service/database" + "github.com/google/uuid" +) + +// ListProvisionedInstances returns all provisioned database instances in the workspace. +func ListProvisionedInstances(ctx context.Context, w *databricks.WorkspaceClient) ([]database.DatabaseInstance, error) { + return w.Database.ListDatabaseInstancesAll(ctx, database.ListDatabaseInstancesRequest{}) +} + +// GetProvisioned fetches a single provisioned instance by name. +// The Name field on the response can be empty; this function ensures it is +// populated from the input so downstream callers do not have to re-set it. 
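+//
+// Illustrative use (instance name is hypothetical):
+//
+//	inst, err := GetProvisioned(ctx, w, "my-instance")
+//	// inst.Name == "my-instance" even if the API response omitted the field.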
+func GetProvisioned(ctx context.Context, w *databricks.WorkspaceClient, name string) (*database.DatabaseInstance, error) { + instance, err := w.Database.GetDatabaseInstance(ctx, database.GetDatabaseInstanceRequest{Name: name}) + if err != nil { + return nil, fmt.Errorf("failed to get database instance: %w", err) + } + if instance.Name == "" { + instance.Name = name + } + return instance, nil +} + +// AutoSelectProvisioned returns the only provisioned instance in the workspace, +// or an AmbiguousError if there are multiple. Returns a plain error if none. +func AutoSelectProvisioned(ctx context.Context, w *databricks.WorkspaceClient) (string, error) { + instances, err := ListProvisionedInstances(ctx, w) + if err != nil { + return "", err + } + if len(instances) == 0 { + return "", errors.New("no Lakebase Provisioned instances found in workspace") + } + if len(instances) == 1 { + return instances[0].Name, nil + } + + choices := make([]Choice, 0, len(instances)) + for _, inst := range instances { + choices = append(choices, Choice{ID: inst.Name, DisplayName: inst.Name}) + } + return "", &AmbiguousError{Kind: "instance", FlagHint: "--target", Choices: choices} +} + +// ProvisionedCredential issues a short-lived OAuth token for the provisioned +// instance with the given name. +func ProvisionedCredential(ctx context.Context, w *databricks.WorkspaceClient, instanceName string) (string, error) { + cred, err := w.Database.GenerateDatabaseCredential(ctx, database.GenerateDatabaseCredentialRequest{ + InstanceNames: []string{instanceName}, + RequestId: uuid.NewString(), + }) + if err != nil { + return "", fmt.Errorf("failed to get database credentials: %w", err) + } + return cred.Token, nil +} diff --git a/libs/lakebase/target/target.go b/libs/lakebase/target/target.go new file mode 100644 index 00000000000..d02c95903ce --- /dev/null +++ b/libs/lakebase/target/target.go @@ -0,0 +1,145 @@ +// Package target resolves Lakebase Postgres targets (provisioned instances and +// autoscaling endpoints) into the host, credential, and SDK metadata that +// callers need to open a connection. It is shared by `cmd/psql` and the +// `experimental postgres query` command so that both speak the same SDK. +package target + +import ( + "errors" + "fmt" + "strings" +) + +const ( + // PathSegmentProjects is the leading path segment that identifies an + // autoscaling resource path. Provisioned instance names never start with it. + PathSegmentProjects = "projects" + pathSegmentBranches = "branches" + pathSegmentEndpoints = "endpoints" +) + +// AutoscalingSpec is a partial or full specification for an autoscaling endpoint. +// Empty fields signal "auto-select if exactly one exists, otherwise error". +type AutoscalingSpec struct { + ProjectID string + BranchID string + EndpointID string +} + +// Choice is a single candidate returned alongside an AmbiguousError so callers +// can either render the list to the user or prompt interactively. +type Choice struct { + ID string + DisplayName string +} + +// AmbiguousError is returned by AutoSelect* helpers when the SDK returns more +// than one candidate and the caller did not specify which one to pick. +// +// Callers that have a TTY (e.g. `databricks psql`) can use errors.As to detect +// this and prompt interactively. Callers that are non-interactive (e.g. the +// scriptable `postgres query` command) propagate it as a plain error: the +// formatted message already enumerates the choices. 
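+//
+// The rendered message enumerates the choices, e.g. (from the unit tests):
+//
+//	multiple branches found in projects/foo; specify --branch:
+//	 - main
+//	 - feature-x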
+type AmbiguousError struct { + // Kind identifies what was ambiguous: "project", "branch", or "endpoint". + Kind string + // Parent is the SDK resource name that contained the ambiguity (e.g. + // "projects/foo" when listing branches), or empty when listing projects. + Parent string + // FlagHint is the flag a user would set to disambiguate (e.g. "--branch"). + FlagHint string + // Choices enumerates the candidates returned by the SDK. + Choices []Choice +} + +func (e *AmbiguousError) Error() string { + plural := map[string]string{ + "project": "projects", + "branch": "branches", + "endpoint": "endpoints", + "instance": "instances", + }[e.Kind] + if plural == "" { + plural = e.Kind + } + + var sb strings.Builder + if e.Parent == "" { + fmt.Fprintf(&sb, "multiple %s found; specify %s:", plural, e.FlagHint) + } else { + fmt.Fprintf(&sb, "multiple %s found in %s; specify %s:", plural, e.Parent, e.FlagHint) + } + for _, c := range e.Choices { + sb.WriteString("\n - ") + sb.WriteString(c.ID) + if c.DisplayName != "" && c.DisplayName != c.ID { + fmt.Fprintf(&sb, " (%s)", c.DisplayName) + } + } + return sb.String() +} + +// ParseAutoscalingPath extracts project, branch, and endpoint IDs from a +// resource path. Accepts partial paths: +// +// projects/foo +// projects/foo/branches/bar +// projects/foo/branches/bar/endpoints/baz +// +// Returns an error if the path is malformed or does not start with "projects/". +func ParseAutoscalingPath(input string) (AutoscalingSpec, error) { + parts := strings.Split(input, "/") + + if len(parts) < 2 || parts[0] != PathSegmentProjects { + return AutoscalingSpec{}, fmt.Errorf("invalid resource path: %s", input) + } + if parts[1] == "" { + return AutoscalingSpec{}, errors.New("invalid resource path: missing project ID") + } + spec := AutoscalingSpec{ProjectID: parts[1]} + + if len(parts) > 2 { + if len(parts) < 4 || parts[2] != pathSegmentBranches { + return AutoscalingSpec{}, errors.New("invalid resource path: expected 'branches' after project") + } + if parts[3] == "" { + return AutoscalingSpec{}, errors.New("invalid resource path: missing branch ID") + } + spec.BranchID = parts[3] + } + + if len(parts) > 4 { + if len(parts) < 6 || parts[4] != pathSegmentEndpoints { + return AutoscalingSpec{}, errors.New("invalid resource path: expected 'endpoints' after branch") + } + if parts[5] == "" { + return AutoscalingSpec{}, errors.New("invalid resource path: missing endpoint ID") + } + spec.EndpointID = parts[5] + } + + if len(parts) > 6 { + return AutoscalingSpec{}, fmt.Errorf("invalid resource path: trailing components after endpoint: %s", input) + } + + return spec, nil +} + +// ExtractID returns the value following component in a resource name. +// ExtractID("projects/foo/branches/bar", "branches") returns "bar". +// Returns the original name unchanged if component is not found. +func ExtractID(name, component string) string { + parts := strings.Split(name, "/") + for i := range len(parts) - 1 { + if parts[i] == component { + return parts[i+1] + } + } + return name +} + +// IsAutoscalingPath reports whether s is an autoscaling resource path +// (i.e. starts with "projects/"). Provisioned instance names never do. 
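+//
+//	IsAutoscalingPath("projects/foo") // true
+//	IsAutoscalingPath("my-instance")  // false
+//	IsAutoscalingPath("projects")     // false: no trailing slash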
+func IsAutoscalingPath(s string) bool { + return strings.HasPrefix(s, PathSegmentProjects+"/") +} diff --git a/libs/lakebase/target/target_test.go b/libs/lakebase/target/target_test.go new file mode 100644 index 00000000000..4b4a763c122 --- /dev/null +++ b/libs/lakebase/target/target_test.go @@ -0,0 +1,136 @@ +package target + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseAutoscalingPath(t *testing.T) { + tests := []struct { + name string + input string + want AutoscalingSpec + wantErr string + }{ + { + name: "project only", + input: "projects/my-project", + want: AutoscalingSpec{ProjectID: "my-project"}, + }, + { + name: "project and branch", + input: "projects/my-project/branches/main", + want: AutoscalingSpec{ProjectID: "my-project", BranchID: "main"}, + }, + { + name: "full path", + input: "projects/my-project/branches/main/endpoints/primary", + want: AutoscalingSpec{ProjectID: "my-project", BranchID: "main", EndpointID: "primary"}, + }, + { + name: "missing project ID", + input: "projects/", + wantErr: "missing project ID", + }, + { + name: "missing branch ID", + input: "projects/my-project/branches/", + wantErr: "missing branch ID", + }, + { + name: "missing endpoint ID", + input: "projects/my-project/branches/main/endpoints/", + wantErr: "missing endpoint ID", + }, + { + name: "invalid segment after project", + input: "projects/my-project/invalid/foo", + wantErr: "expected 'branches' after project", + }, + { + name: "invalid segment after branch", + input: "projects/my-project/branches/main/invalid/foo", + wantErr: "expected 'endpoints' after branch", + }, + { + name: "not a projects path", + input: "something/else", + wantErr: "invalid resource path", + }, + { + name: "trailing components after endpoint", + input: "projects/foo/branches/bar/endpoints/baz/extra", + wantErr: "trailing components after endpoint", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := ParseAutoscalingPath(tc.input) + if tc.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.wantErr) + return + } + require.NoError(t, err) + assert.Equal(t, tc.want, got) + }) + } +} + +func TestExtractID(t *testing.T) { + assert.Equal(t, "bar", ExtractID("projects/foo/branches/bar", "branches")) + assert.Equal(t, "foo", ExtractID("projects/foo", "projects")) + assert.Equal(t, "baz", ExtractID("projects/foo/branches/bar/endpoints/baz", "endpoints")) + assert.Equal(t, "no-component", ExtractID("no-component", "missing")) +} + +func TestIsAutoscalingPath(t *testing.T) { + assert.True(t, IsAutoscalingPath("projects/foo")) + assert.True(t, IsAutoscalingPath("projects/foo/branches/bar")) + assert.False(t, IsAutoscalingPath("my-instance")) + assert.False(t, IsAutoscalingPath("")) + assert.False(t, IsAutoscalingPath("projects")) +} + +func TestAmbiguousErrorMessage(t *testing.T) { + t.Run("with parent", func(t *testing.T) { + err := &AmbiguousError{ + Kind: "branch", + Parent: "projects/foo", + FlagHint: "--branch", + Choices: []Choice{ + {ID: "main", DisplayName: "main"}, + {ID: "feature-x", DisplayName: "feature-x"}, + }, + } + assert.Equal(t, + "multiple branches found in projects/foo; specify --branch:\n - main\n - feature-x", + err.Error(), + ) + }) + + t.Run("without parent", func(t *testing.T) { + err := &AmbiguousError{ + Kind: "project", + FlagHint: "--project", + Choices: []Choice{ + {ID: "alpha", DisplayName: "Alpha Project"}, + {ID: "beta", DisplayName: "beta"}, + }, + } + 
assert.Equal(t, + "multiple projects found; specify --project:\n - alpha (Alpha Project)\n - beta", + err.Error(), + ) + }) + + t.Run("errors.As", func(t *testing.T) { + var amb *AmbiguousError + err := error(&AmbiguousError{Kind: "endpoint", FlagHint: "--endpoint"}) + assert.ErrorAs(t, err, &amb) + assert.Equal(t, "endpoint", amb.Kind) + }) +} From bfb632090cd6e4e5ec75e69720de781a2eba2be8 Mon Sep 17 00:00:00 2001 From: simon Date: Thu, 30 Apr 2026 10:30:02 +0200 Subject: [PATCH 02/25] Add experimental postgres query command (autoscaling, text output) Scaffolds 'databricks experimental postgres query', a scriptable SQL runner against a Lakebase Postgres autoscaling endpoint that does not require a system psql binary. This PR ships the smallest useful slice: - Single positional SQL statement. - --target (autoscaling resource path), --project, --branch, --endpoint targeting forms; provisioned-shaped targets return a pointer at 'databricks psql' for now. - Connect retry on idle/waking endpoints (08xxx SQLSTATE family, dial errors). - Text output (static table for rows-producing statements, command tag for command-only). Provisioned support, JSON/CSV streaming output, multi-statement input, cancellation, and integration tests come in follow-up PRs. Driver: github.com/jackc/pgx/v5 v5.9.1 (MIT). Already a direct dep of the universe monorepo's Lakebase services; aligning here keeps the SDK surface consistent. Co-authored-by: Isaac --- NEXT_CHANGELOG.md | 2 + NOTICE | 4 + .../query/ambiguous-targeting/out.test.toml | 8 + .../query/ambiguous-targeting/output.txt | 18 ++ .../postgres/query/ambiguous-targeting/script | 8 + .../query/ambiguous-targeting/test.toml | 62 +++++++ .../query/argument-errors/out.test.toml | 8 + .../postgres/query/argument-errors/output.txt | 40 ++++ .../postgres/query/argument-errors/script | 29 +++ .../postgres/query/argument-errors/test.toml | 3 + cmd/experimental/experimental.go | 2 + experimental/postgres/cmd/cmd.go | 25 +++ experimental/postgres/cmd/connect.go | 147 +++++++++++++++ experimental/postgres/cmd/connect_test.go | 149 +++++++++++++++ experimental/postgres/cmd/execute.go | 62 +++++++ experimental/postgres/cmd/query.go | 133 ++++++++++++++ experimental/postgres/cmd/render.go | 74 ++++++++ experimental/postgres/cmd/render_test.go | 67 +++++++ experimental/postgres/cmd/targeting.go | 173 ++++++++++++++++++ experimental/postgres/cmd/targeting_test.go | 81 ++++++++ go.mod | 3 + go.sum | 10 + 22 files changed, 1108 insertions(+) create mode 100644 acceptance/cmd/experimental/postgres/query/ambiguous-targeting/out.test.toml create mode 100644 acceptance/cmd/experimental/postgres/query/ambiguous-targeting/output.txt create mode 100644 acceptance/cmd/experimental/postgres/query/ambiguous-targeting/script create mode 100644 acceptance/cmd/experimental/postgres/query/ambiguous-targeting/test.toml create mode 100644 acceptance/cmd/experimental/postgres/query/argument-errors/out.test.toml create mode 100644 acceptance/cmd/experimental/postgres/query/argument-errors/output.txt create mode 100644 acceptance/cmd/experimental/postgres/query/argument-errors/script create mode 100644 acceptance/cmd/experimental/postgres/query/argument-errors/test.toml create mode 100644 experimental/postgres/cmd/cmd.go create mode 100644 experimental/postgres/cmd/connect.go create mode 100644 experimental/postgres/cmd/connect_test.go create mode 100644 experimental/postgres/cmd/execute.go create mode 100644 experimental/postgres/cmd/query.go create mode 100644 experimental/postgres/cmd/render.go 
create mode 100644 experimental/postgres/cmd/render_test.go create mode 100644 experimental/postgres/cmd/targeting.go create mode 100644 experimental/postgres/cmd/targeting_test.go diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 00152d550ea..be66fe3964b 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -7,3 +7,5 @@ ### Bundles ### Dependency updates + +* Added `github.com/jackc/pgx/v5` v5.9.1 (MIT) as a new dependency. Used by an experimental Postgres command added in this release; the package is dormant for users who do not invoke that command. diff --git a/NOTICE b/NOTICE index 1e286df6f91..7077be46928 100644 --- a/NOTICE +++ b/NOTICE @@ -127,6 +127,10 @@ google/jsonschema-go - https://github.com/google/jsonschema-go Copyright 2025 Google LLC License - https://github.com/google/jsonschema-go/blob/main/LICENSE +jackc/pgx - https://github.com/jackc/pgx +Copyright (c) 2013-2021 Jack Christensen +License - https://github.com/jackc/pgx/blob/master/LICENSE + charmbracelet/bubbles - https://github.com/charmbracelet/bubbles Copyright (c) 2020-2025 Charmbracelet, Inc License - https://github.com/charmbracelet/bubbles/blob/master/LICENSE diff --git a/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/out.test.toml b/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/out.test.toml new file mode 100644 index 00000000000..40bb0d10471 --- /dev/null +++ b/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/out.test.toml @@ -0,0 +1,8 @@ +Local = true +Cloud = false + +[GOOS] + windows = false + +[EnvMatrix] + DATABRICKS_BUNDLE_ENGINE = ["terraform", "direct"] diff --git a/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/output.txt b/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/output.txt new file mode 100644 index 00000000000..e95a7b3613d --- /dev/null +++ b/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/output.txt @@ -0,0 +1,18 @@ + +=== Project with multiple branches and no --branch should error with choices: +>>> musterr [CLI] experimental postgres query --project foo SELECT 1 +Error: multiple branches found in projects/foo; specify --branch: + - main + - dev + +=== Project with multiple endpoints in only branch should error with choices: +>>> musterr [CLI] experimental postgres query --project bar SELECT 1 +Error: multiple endpoints found in projects/bar/branches/only; specify --endpoint: + - read-write + - read-only + +=== Partial path with multiple branches should error with choices: +>>> musterr [CLI] experimental postgres query --target projects/foo SELECT 1 +Error: multiple branches found in projects/foo; specify --branch: + - main + - dev diff --git a/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/script b/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/script new file mode 100644 index 00000000000..6143fd96f02 --- /dev/null +++ b/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/script @@ -0,0 +1,8 @@ +title "Project with multiple branches and no --branch should error with choices:" +trace musterr $CLI experimental postgres query --project foo "SELECT 1" + +title "Project with multiple endpoints in only branch should error with choices:" +trace musterr $CLI experimental postgres query --project bar "SELECT 1" + +title "Partial path with multiple branches should error with choices:" +trace musterr $CLI experimental postgres query --target projects/foo "SELECT 1" diff --git a/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/test.toml 
b/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/test.toml new file mode 100644 index 00000000000..2a61e7e8e25 --- /dev/null +++ b/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/test.toml @@ -0,0 +1,62 @@ +GOOS.windows = false + +[[Server]] +Pattern = "GET /api/2.0/postgres/projects" +Response.Body = ''' +{ + "projects": [ + {"name": "projects/alpha", "status": {"display_name": "Alpha"}}, + {"name": "projects/beta", "status": {"display_name": "Beta"}} + ] +} +''' + +[[Server]] +Pattern = "GET /api/2.0/postgres/projects/foo" +Response.Body = ''' +{ + "name": "projects/foo", + "status": {"display_name": "Foo Project"} +} +''' + +[[Server]] +Pattern = "GET /api/2.0/postgres/projects/foo/branches" +Response.Body = ''' +{ + "branches": [ + {"name": "projects/foo/branches/main"}, + {"name": "projects/foo/branches/dev"} + ] +} +''' + +[[Server]] +Pattern = "GET /api/2.0/postgres/projects/bar" +Response.Body = ''' +{ + "name": "projects/bar", + "status": {"display_name": "Bar Project"} +} +''' + +[[Server]] +Pattern = "GET /api/2.0/postgres/projects/bar/branches" +Response.Body = ''' +{ + "branches": [ + {"name": "projects/bar/branches/only"} + ] +} +''' + +[[Server]] +Pattern = "GET /api/2.0/postgres/projects/bar/branches/only/endpoints" +Response.Body = ''' +{ + "endpoints": [ + {"name": "projects/bar/branches/only/endpoints/read-write"}, + {"name": "projects/bar/branches/only/endpoints/read-only"} + ] +} +''' diff --git a/acceptance/cmd/experimental/postgres/query/argument-errors/out.test.toml b/acceptance/cmd/experimental/postgres/query/argument-errors/out.test.toml new file mode 100644 index 00000000000..40bb0d10471 --- /dev/null +++ b/acceptance/cmd/experimental/postgres/query/argument-errors/out.test.toml @@ -0,0 +1,8 @@ +Local = true +Cloud = false + +[GOOS] + windows = false + +[EnvMatrix] + DATABRICKS_BUNDLE_ENGINE = ["terraform", "direct"] diff --git a/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt b/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt new file mode 100644 index 00000000000..59ddbfedc6e --- /dev/null +++ b/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt @@ -0,0 +1,40 @@ + +=== No SQL argument should error: +>>> musterr [CLI] experimental postgres query --target projects/foo +Error: accepts 1 arg(s), received 0 + +=== Empty SQL should error: +>>> musterr [CLI] experimental postgres query --target projects/foo +Error: no SQL provided + +=== Neither targeting form should error: +>>> musterr [CLI] experimental postgres query SELECT 1 +Error: must specify --target or --project + +=== Both --target and --project should error: +>>> musterr [CLI] experimental postgres query --target projects/foo --project foo SELECT 1 +Error: if any flags in the group [target project] are set none of the others can be; [project target] were all set + +=== Both --target and --branch should error: +>>> musterr [CLI] experimental postgres query --target projects/foo --branch main SELECT 1 +Error: if any flags in the group [target branch] are set none of the others can be; [branch target] were all set + +=== Branch without project should error: +>>> musterr [CLI] experimental postgres query --branch main SELECT 1 +Error: --project is required when using --branch or --endpoint + +=== Endpoint without project should error: +>>> musterr [CLI] experimental postgres query --endpoint primary SELECT 1 +Error: --project is required when using --branch or --endpoint + +=== Provisioned-shaped target should error 
pointing at psql: +>>> musterr [CLI] experimental postgres query --target my-instance SELECT 1 +Error: provisioned instances are not yet supported by this experimental command; use 'databricks psql ' for now + +=== Malformed autoscaling path should error: +>>> musterr [CLI] experimental postgres query --target projects/ SELECT 1 +Error: invalid resource path: missing project ID + +=== Trailing components after endpoint should error: +>>> musterr [CLI] experimental postgres query --target projects/foo/branches/bar/endpoints/baz/extra SELECT 1 +Error: invalid resource path: trailing components after endpoint: projects/foo/branches/bar/endpoints/baz/extra diff --git a/acceptance/cmd/experimental/postgres/query/argument-errors/script b/acceptance/cmd/experimental/postgres/query/argument-errors/script new file mode 100644 index 00000000000..5874c843a03 --- /dev/null +++ b/acceptance/cmd/experimental/postgres/query/argument-errors/script @@ -0,0 +1,29 @@ +title "No SQL argument should error:" +trace musterr $CLI experimental postgres query --target projects/foo + +title "Empty SQL should error:" +trace musterr $CLI experimental postgres query --target projects/foo " " + +title "Neither targeting form should error:" +trace musterr $CLI experimental postgres query "SELECT 1" + +title "Both --target and --project should error:" +trace musterr $CLI experimental postgres query --target projects/foo --project foo "SELECT 1" + +title "Both --target and --branch should error:" +trace musterr $CLI experimental postgres query --target projects/foo --branch main "SELECT 1" + +title "Branch without project should error:" +trace musterr $CLI experimental postgres query --branch main "SELECT 1" + +title "Endpoint without project should error:" +trace musterr $CLI experimental postgres query --endpoint primary "SELECT 1" + +title "Provisioned-shaped target should error pointing at psql:" +trace musterr $CLI experimental postgres query --target my-instance "SELECT 1" + +title "Malformed autoscaling path should error:" +trace musterr $CLI experimental postgres query --target projects/ "SELECT 1" + +title "Trailing components after endpoint should error:" +trace musterr $CLI experimental postgres query --target projects/foo/branches/bar/endpoints/baz/extra "SELECT 1" diff --git a/acceptance/cmd/experimental/postgres/query/argument-errors/test.toml b/acceptance/cmd/experimental/postgres/query/argument-errors/test.toml new file mode 100644 index 00000000000..3371f08de12 --- /dev/null +++ b/acceptance/cmd/experimental/postgres/query/argument-errors/test.toml @@ -0,0 +1,3 @@ +# Argument validation runs before any SDK call. No mocked HTTP responses are +# needed; CLI either errors at flag-parse time or at our own validate function. +GOOS.windows = false diff --git a/cmd/experimental/experimental.go b/cmd/experimental/experimental.go index 36ad8765898..52c6bac79b0 100644 --- a/cmd/experimental/experimental.go +++ b/cmd/experimental/experimental.go @@ -2,6 +2,7 @@ package experimental import ( aitoolscmd "github.com/databricks/cli/experimental/aitools/cmd" + postgrescmd "github.com/databricks/cli/experimental/postgres/cmd" "github.com/spf13/cobra" ) @@ -21,6 +22,7 @@ development. 
They may change or be removed in future versions without notice.`, } cmd.AddCommand(aitoolscmd.NewAitoolsCmd()) + cmd.AddCommand(postgrescmd.New()) cmd.AddCommand(newWorkspaceOpenCommand()) return cmd diff --git a/experimental/postgres/cmd/cmd.go b/experimental/postgres/cmd/cmd.go new file mode 100644 index 00000000000..8db7b46be86 --- /dev/null +++ b/experimental/postgres/cmd/cmd.go @@ -0,0 +1,25 @@ +// Package postgrescmd registers the `databricks experimental postgres ...` +// command tree. The current sub-tree provides `query`, a scriptable SQL +// runner against any Lakebase Postgres endpoint that does not require a +// system `psql` binary. +package postgrescmd + +import ( + "github.com/spf13/cobra" +) + +// New returns the root `postgres` experimental command. It is hidden by its +// experimental parent; the command itself is always visible once one of its +// subcommands is reached. +func New() *cobra.Command { + cmd := &cobra.Command{ + Use: "postgres", + Short: "Experimental Lakebase Postgres commands", + Long: `Experimental commands for interacting with Lakebase Postgres endpoints. + +These commands are still under development and may change without notice.`, + } + + cmd.AddCommand(newQueryCmd()) + return cmd +} diff --git a/experimental/postgres/cmd/connect.go b/experimental/postgres/cmd/connect.go new file mode 100644 index 00000000000..a0674b81ead --- /dev/null +++ b/experimental/postgres/cmd/connect.go @@ -0,0 +1,147 @@ +package postgrescmd + +import ( + "context" + "errors" + "fmt" + "net" + "time" + + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/log" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +// defaultConnectTimeout is the dial timeout for a single connect attempt. +// Lakebase autoscaling endpoints can be cold-starting; Postgres' own dial +// keeps trying within this window before giving up. +const defaultConnectTimeout = 120 * time.Second + +// connectConfig collects everything pgx needs to dial Postgres. Kept as a +// struct rather than passed through positional args because the pgx config +// has many fields and the call sites differ between code paths (production +// vs unit tests stubbing connectFunc). +type connectConfig struct { + Host string + Port int + Username string + Password string + Database string + ConnectTimeout time.Duration +} + +// retryConfig controls connect retry on idle/waking endpoints. MaxAttempts is +// the total number of attempts: 1 means no retry, 3 means up to two retries +// with backoff between. We use the count-of-attempts reading rather than +// count-of-retries to match libs/psql.RetryConfig.MaxRetries semantics, so +// behavior stays consistent across the two commands sharing a flag name. +type retryConfig struct { + MaxAttempts int + InitialDelay time.Duration + MaxDelay time.Duration +} + +// connectFunc is a seam for unit tests: production wires pgx.ConnectConfig, +// tests inject failures (DNS, auth, ctx-cancel mid-connect). We deliberately +// do not wrap *pgx.Conn behind an interface for query execution; that surface +// is exercised by integration tests against real Lakebase endpoints. +type connectFunc func(ctx context.Context, cfg *pgx.ConnConfig) (*pgx.Conn, error) + +// buildPgxConfig parses a base DSN to inherit pgx's TLS shape, then patches +// in the resolved values. The DSN-then-patch pattern is the recommended way +// to configure pgx for `sslmode=require` because building a pgx.ConnConfig +// by hand omits internal fields that the parser sets. 
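+//
+// Usage sketch (host and credential values are illustrative):
+//
+//	cfg, err := buildPgxConfig(connectConfig{
+//		Host:           "ep.example.databricks.com",
+//		Port:           5432,
+//		Username:       "user@example.com",
+//		Password:       token,
+//		Database:       defaultDatabase,
+//		ConnectTimeout: defaultConnectTimeout,
+//	})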
+func buildPgxConfig(c connectConfig) (*pgx.ConnConfig, error) { + cfg, err := pgx.ParseConfig("postgresql:///?sslmode=require") + if err != nil { + return nil, fmt.Errorf("parse pgx config: %w", err) + } + cfg.Host = c.Host + cfg.Port = uint16(c.Port) + cfg.User = c.Username + cfg.Password = c.Password + cfg.Database = c.Database + cfg.ConnectTimeout = c.ConnectTimeout + return cfg, nil +} + +// connectWithRetry dials Postgres, retrying on connect-time errors that +// indicate the endpoint is asleep or in the middle of a wake-up. Errors that +// cannot be improved by retrying (auth failures, permission errors, +// post-query errors) are returned immediately. +func connectWithRetry(ctx context.Context, cfg *pgx.ConnConfig, rc retryConfig, dial connectFunc) (*pgx.Conn, error) { + if rc.MaxAttempts < 1 { + rc.MaxAttempts = 1 + } + + delay := rc.InitialDelay + var lastErr error + + for attempt := 1; attempt <= rc.MaxAttempts; attempt++ { + if attempt > 1 { + cmdio.LogString(ctx, fmt.Sprintf("Connection attempt %d/%d failed, retrying in %v...", attempt-1, rc.MaxAttempts, delay)) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(delay): + } + if rc.MaxDelay > 0 { + delay = min(delay*2, rc.MaxDelay) + } + } + + conn, err := dial(ctx, cfg) + if err == nil { + return conn, nil + } + lastErr = err + + if !isRetryableConnectError(err) { + return nil, err + } + log.Debugf(ctx, "retryable connect error on attempt %d: %v", attempt, err) + } + + return nil, fmt.Errorf("failed to connect after %d attempts: %w", rc.MaxAttempts, lastErr) +} + +// isRetryableConnectError classifies whether an error from the connect path +// is a transient "endpoint asleep / cold-starting" failure. +// +// Retryable: +// - net.OpError with Op == "dial" (DNS resolution, TCP connect refused, +// host unreachable). The "endpoint asleep" cases. +// - pgconn.ConnectError that wraps a retryable network error. +// - Postgres connection-establishment SQLSTATE codes (08xxx). Lakebase +// emits these during cold-start. +// +// Not retryable: auth errors (28xxx), permission errors (42501), +// context cancellation/deadlines, anything after Query has been issued +// (caller never passes that to us; we only run before Query). +func isRetryableConnectError(err error) bool { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return false + } + + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) { + // 08xxx is the connection_exception class. 
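+		// Examples: 08001 sqlclient_unable_to_establish_sqlconnection,
+		// 08006 connection_failure (both exercised in the unit tests).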
+ if len(pgErr.Code) == 5 && pgErr.Code[:2] == "08" { + return true + } + return false + } + + var connectErr *pgconn.ConnectError + if errors.As(err, &connectErr) { + return isRetryableConnectError(connectErr.Unwrap()) + } + + var opErr *net.OpError + if errors.As(err, &opErr) { + return opErr.Op == "dial" + } + + return false +} diff --git a/experimental/postgres/cmd/connect_test.go b/experimental/postgres/cmd/connect_test.go new file mode 100644 index 00000000000..0f7614b1f31 --- /dev/null +++ b/experimental/postgres/cmd/connect_test.go @@ -0,0 +1,149 @@ +package postgrescmd + +import ( + "context" + "errors" + "net" + "testing" + "time" + + "github.com/databricks/cli/libs/cmdio" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func testCtx(t *testing.T) context.Context { + return cmdio.MockDiscard(t.Context()) +} + +func TestIsRetryableConnectError(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + { + name: "dial error", + err: &net.OpError{Op: "dial", Err: errors.New("connection refused")}, + want: true, + }, + { + name: "non-dial net.OpError", + err: &net.OpError{Op: "read", Err: errors.New("oops")}, + want: false, + }, + { + name: "08006 connection failure", + err: &pgconn.PgError{Code: "08006", Message: "connection failure"}, + want: true, + }, + { + name: "08001 cannot establish", + err: &pgconn.PgError{Code: "08001", Message: "sqlclient unable to establish sqlconnection"}, + want: true, + }, + { + name: "28000 invalid auth", + err: &pgconn.PgError{Code: "28000", Message: "invalid authorization specification"}, + want: false, + }, + { + name: "28P01 invalid password", + err: &pgconn.PgError{Code: "28P01", Message: "invalid password"}, + want: false, + }, + { + name: "42501 insufficient privilege", + err: &pgconn.PgError{Code: "42501", Message: "permission denied"}, + want: false, + }, + { + name: "context cancelled", + err: context.Canceled, + want: false, + }, + { + name: "context deadline exceeded", + err: context.DeadlineExceeded, + want: false, + }, + { + name: "nil error never retryable", + err: nil, + want: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.want, isRetryableConnectError(tc.err)) + }) + } +} + +func TestConnectWithRetry_RespectsMaxAttempts(t *testing.T) { + ctx := testCtx(t) + calls := 0 + dialErr := &pgconn.PgError{Code: "08006"} + dial := func(ctx context.Context, cfg *pgx.ConnConfig) (*pgx.Conn, error) { + calls++ + return nil, dialErr + } + cfg := &pgx.ConnConfig{} + rc := retryConfig{MaxAttempts: 3, InitialDelay: 0, MaxDelay: 0} + + _, err := connectWithRetry(ctx, cfg, rc, dial) + require.Error(t, err) + assert.Equal(t, 3, calls, "expected 3 attempts (1 initial + 2 retries)") +} + +func TestConnectWithRetry_StopsOnNonRetryable(t *testing.T) { + ctx := testCtx(t) + calls := 0 + authErr := &pgconn.PgError{Code: "28P01"} + dial := func(ctx context.Context, cfg *pgx.ConnConfig) (*pgx.Conn, error) { + calls++ + return nil, authErr + } + cfg := &pgx.ConnConfig{} + rc := retryConfig{MaxAttempts: 3, InitialDelay: 0} + + _, err := connectWithRetry(ctx, cfg, rc, dial) + require.Error(t, err) + assert.Equal(t, 1, calls, "auth errors should not retry") +} + +func TestConnectWithRetry_ZeroMaxAttemptsTreatedAsOne(t *testing.T) { + ctx := testCtx(t) + calls := 0 + dial := func(ctx context.Context, cfg *pgx.ConnConfig) (*pgx.Conn, error) { + calls++ + return nil, errors.New("nope") + } 
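+	// The generic error is non-retryable, so the clamp (MaxAttempts 0 -> 1)
+	// is what guarantees exactly one dial rather than zero.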
+ cfg := &pgx.ConnConfig{} + rc := retryConfig{MaxAttempts: 0, InitialDelay: time.Millisecond} + + _, err := connectWithRetry(ctx, cfg, rc, dial) + require.Error(t, err) + assert.Equal(t, 1, calls) +} + +func TestBuildPgxConfig(t *testing.T) { + cfg, err := buildPgxConfig(connectConfig{ + Host: "host.example.com", + Port: 5432, + Username: "user", + Password: "secret", + Database: "db", + ConnectTimeout: 30 * time.Second, + }) + require.NoError(t, err) + assert.Equal(t, "host.example.com", cfg.Host) + assert.Equal(t, uint16(5432), cfg.Port) + assert.Equal(t, "user", cfg.User) + assert.Equal(t, "secret", cfg.Password) + assert.Equal(t, "db", cfg.Database) + assert.Equal(t, 30*time.Second, cfg.ConnectTimeout) +} diff --git a/experimental/postgres/cmd/execute.go b/experimental/postgres/cmd/execute.go new file mode 100644 index 00000000000..c29f7ce59d6 --- /dev/null +++ b/experimental/postgres/cmd/execute.go @@ -0,0 +1,62 @@ +package postgrescmd + +import ( + "context" + "fmt" + + "github.com/jackc/pgx/v5" +) + +// executeOne runs a single SQL statement against an open connection and +// captures the result in a queryResult. +// +// We pass QueryExecModeExec explicitly (not the pgx default +// QueryExecModeCacheStatement) for two reasons: +// +// 1. Statement caching has no benefit for a one-shot CLI: the connection is +// closed at the end of the command, so the cached prepared statement +// never gets reused. +// 2. Exec mode uses Postgres' extended-protocol "exec" path with text-format +// result columns. That makes rendering canonical-Postgres-text output +// (PR 1) and CSV (later PR) straightforward; the cache mode defaults to +// binary and we'd be reformatting back to text. +// +// QueryExecModeExec still uses extended protocol with a single statement and +// no implicit transaction wrap, so transaction-disallowed DDL like +// `CREATE DATABASE` works. +func executeOne(ctx context.Context, conn *pgx.Conn, sql string) (*queryResult, error) { + rows, err := conn.Query(ctx, sql, pgx.QueryExecModeExec) + if err != nil { + return nil, fmt.Errorf("query failed: %w", err) + } + defer rows.Close() + + result := &queryResult{SQL: sql} + + fields := rows.FieldDescriptions() + if len(fields) > 0 { + result.Columns = make([]string, len(fields)) + for i, f := range fields { + result.Columns[i] = f.Name + } + } + + for rows.Next() { + raw := rows.RawValues() + row := make([]string, len(raw)) + for i, b := range raw { + if b == nil { + row[i] = "NULL" + continue + } + row[i] = string(b) + } + result.Rows = append(result.Rows, row) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("query failed: %w", err) + } + + result.CommandTag = rows.CommandTag().String() + return result, nil +} diff --git a/experimental/postgres/cmd/query.go b/experimental/postgres/cmd/query.go new file mode 100644 index 00000000000..643aa496e84 --- /dev/null +++ b/experimental/postgres/cmd/query.go @@ -0,0 +1,133 @@ +package postgrescmd + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/cmdio" + "github.com/jackc/pgx/v5" + "github.com/spf13/cobra" +) + +// defaultDatabase is the database name used when --database is not set. +// Lakebase Autoscaling and Provisioned both use this name as their default. +const defaultDatabase = "databricks_postgres" + +// queryFlags is the union of every flag the query command exposes. Lifted +// out of newQueryCmd so unit-tested helpers (resolveTarget, etc.) 
can take
+// it directly without poking at cobra internals.
+type queryFlags struct {
+	targetingFlags
+	database       string
+	connectTimeout time.Duration
+	maxRetries     int
+}
+
+func newQueryCmd() *cobra.Command {
+	var f queryFlags
+
+	cmd := &cobra.Command{
+		Use:     "query [SQL]",
+		Short:   "Run a SQL statement against a Lakebase Postgres endpoint",
+		GroupID: "",
+		Long: `Execute a single SQL statement against a Lakebase Postgres endpoint and
+render the result as text.
+
+Targeting (exactly one form required):
+  --target STRING    Autoscaling resource path
+                     (e.g. projects/foo/branches/main/endpoints/primary)
+  --project ID       Autoscaling project ID
+  --branch ID        Autoscaling branch ID (default: auto-select if exactly one)
+  --endpoint ID      Autoscaling endpoint ID (default: auto-select if exactly one)
+
+This is an experimental command. The flag set, output shape, and supported
+target kinds will expand in subsequent releases.
+
+Limitations (this release):
+
+  - Single SQL statement per invocation; multi-statement strings (e.g.
+    "SELECT 1; SELECT 2") are rejected. Multi-statement support comes later.
+  - Text output only. JSON and CSV output come in a follow-up release.
+  - Only Lakebase Autoscaling endpoints are supported. Provisioned instance
+    support comes in a follow-up release; use 'databricks psql ' as a
+    workaround for now.
+  - No interactive REPL. 'databricks psql' continues to own that surface.
+  - The OAuth token is generated once per invocation and is valid for 1h.
+    Queries longer than that fail with an auth error.
+`,
+		Args:    cobra.ExactArgs(1),
+		PreRunE: root.MustWorkspaceClient,
+		RunE: func(cmd *cobra.Command, args []string) error {
+			return runQuery(cmd.Context(), cmd, args[0], f)
+		},
+	}
+
+	cmd.Flags().StringVar(&f.target, "target", "", "Autoscaling resource path (e.g. projects/foo/branches/main/endpoints/primary)")
+	cmd.Flags().StringVar(&f.project, "project", "", "Autoscaling project ID")
+	cmd.Flags().StringVar(&f.branch, "branch", "", "Autoscaling branch ID (default: auto-select if exactly one)")
+	cmd.Flags().StringVar(&f.endpoint, "endpoint", "", "Autoscaling endpoint ID (default: auto-select if exactly one)")
+	cmd.Flags().StringVarP(&f.database, "database", "d", defaultDatabase, "Database name")
+	cmd.Flags().DurationVar(&f.connectTimeout, "connect-timeout", defaultConnectTimeout, "Connect timeout")
+	cmd.Flags().IntVar(&f.maxRetries, "max-retries", 3, "Total connect attempts on idle/waking endpoint (1 disables retry)")
+
+	cmd.MarkFlagsMutuallyExclusive("target", "project")
+	cmd.MarkFlagsMutuallyExclusive("target", "branch")
+	cmd.MarkFlagsMutuallyExclusive("target", "endpoint")
+
+	return cmd
+}
+
+// runQuery is the production entry point. It is split out from RunE so unit
+// tests can call it directly with a stubbed connectFunc once we add seam-based
+// tests in a later PR.
+func runQuery(ctx context.Context, cmd *cobra.Command, sql string, f queryFlags) error { + sql = strings.TrimSpace(sql) + if sql == "" { + return errors.New("no SQL provided") + } + if err := validateTargeting(f.targetingFlags); err != nil { + return err + } + + resolved, err := resolveTarget(ctx, f.targetingFlags) + if err != nil { + return err + } + + cmdio.LogString(ctx, fmt.Sprintf("Connecting to %s...", resolved.DisplayName)) + + pgxCfg, err := buildPgxConfig(connectConfig{ + Host: resolved.Host, + Port: 5432, + Username: resolved.Username, + Password: resolved.Token, + Database: f.database, + ConnectTimeout: f.connectTimeout, + }) + if err != nil { + return err + } + + rc := retryConfig{ + MaxAttempts: max(1, f.maxRetries), + InitialDelay: time.Second, + MaxDelay: 10 * time.Second, + } + + conn, err := connectWithRetry(ctx, pgxCfg, rc, pgx.ConnectConfig) + if err != nil { + return err + } + defer conn.Close(context.WithoutCancel(ctx)) + + result, err := executeOne(ctx, conn, sql) + if err != nil { + return err + } + + return renderText(cmd.OutOrStdout(), result) +} diff --git a/experimental/postgres/cmd/render.go b/experimental/postgres/cmd/render.go new file mode 100644 index 00000000000..ff923c4a92e --- /dev/null +++ b/experimental/postgres/cmd/render.go @@ -0,0 +1,74 @@ +package postgrescmd + +import ( + "fmt" + "io" + "strings" + "text/tabwriter" +) + +// queryResult is the rendered shape of a single SQL execution. PR 1 only +// renders text; later PRs add JSON and CSV against the same struct. +// +// columns is empty for command-only statements (INSERT, CREATE DATABASE, ...); +// rows is empty when no rows were returned (or for command-only statements). +type queryResult struct { + SQL string + // CommandTag is the Postgres command tag for the statement (e.g. "INSERT 0 5", + // "CREATE DATABASE"). Always set; used for command-only statements and as a + // trailer for rows-producing ones. + CommandTag string + Columns []string + Rows [][]string +} + +// IsRowsProducing reports whether the statement returned a row description. +// Determined at runtime via FieldDescriptions() rather than by parsing the +// leading SQL keyword: `INSERT ... RETURNING` and CTEs ending in a SELECT are +// rows-producing despite their leading keywords. +func (r *queryResult) IsRowsProducing() bool { + return len(r.Columns) > 0 +} + +// renderText writes a result in plain text. +// +// For rows-producing statements we use a tabwriter-aligned table followed by +// a `(N rows)` footer, mimicking psql's compact text shape. For command-only +// statements we just print the command tag. +// +// PR 1 always uses the static (buffered) shape. The interactive table viewer +// for >30 rows lands in a later PR alongside the multi-input output shapes. 
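+//
+// Shape sketch for one row (illustrative; exact spacing comes from
+// tabwriter's two-space padding and headerSeparator's 3-dash minimum):
+//
+//	id   name
+//	---  ----
+//	1    alice
+//	(1 row)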
+func renderText(out io.Writer, r *queryResult) error { + if !r.IsRowsProducing() { + _, err := fmt.Fprintln(out, r.CommandTag) + return err + } + + tw := tabwriter.NewWriter(out, 0, 0, 2, ' ', 0) + fmt.Fprintln(tw, strings.Join(r.Columns, "\t")) + fmt.Fprintln(tw, strings.Join(headerSeparator(r.Columns), "\t")) + for _, row := range r.Rows { + fmt.Fprintln(tw, strings.Join(row, "\t")) + } + if err := tw.Flush(); err != nil { + return err + } + + _, err := fmt.Fprintf(out, "(%d %s)\n", len(r.Rows), pluralize(len(r.Rows), "row", "rows")) + return err +} + +func headerSeparator(cols []string) []string { + out := make([]string, len(cols)) + for i, c := range cols { + out[i] = strings.Repeat("-", max(len(c), 3)) + } + return out +} + +func pluralize(n int, singular, plural string) string { + if n == 1 { + return singular + } + return plural +} diff --git a/experimental/postgres/cmd/render_test.go b/experimental/postgres/cmd/render_test.go new file mode 100644 index 00000000000..29aeb3c36fc --- /dev/null +++ b/experimental/postgres/cmd/render_test.go @@ -0,0 +1,67 @@ +package postgrescmd + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRenderText_RowsProducing(t *testing.T) { + r := &queryResult{ + Columns: []string{"id", "name"}, + Rows: [][]string{ + {"1", "alice"}, + {"2", "bob"}, + }, + CommandTag: "SELECT 2", + } + var buf bytes.Buffer + require.NoError(t, renderText(&buf, r)) + + assert.Equal(t, + "id name\n"+ + "--- ----\n"+ + "1 alice\n"+ + "2 bob\n"+ + "(2 rows)\n", + buf.String(), + ) +} + +func TestRenderText_SingleRow(t *testing.T) { + r := &queryResult{ + Columns: []string{"id"}, + Rows: [][]string{{"42"}}, + CommandTag: "SELECT 1", + } + var buf bytes.Buffer + require.NoError(t, renderText(&buf, r)) + assert.Contains(t, buf.String(), "(1 row)\n") +} + +func TestRenderText_Empty(t *testing.T) { + r := &queryResult{ + Columns: []string{"id", "name"}, + CommandTag: "SELECT 0", + } + var buf bytes.Buffer + require.NoError(t, renderText(&buf, r)) + assert.Contains(t, buf.String(), "(0 rows)\n") +} + +func TestRenderText_CommandOnly(t *testing.T) { + r := &queryResult{ + CommandTag: "INSERT 0 5", + } + var buf bytes.Buffer + require.NoError(t, renderText(&buf, r)) + assert.Equal(t, "INSERT 0 5\n", buf.String()) +} + +func TestQueryResultIsRowsProducing(t *testing.T) { + assert.False(t, (&queryResult{}).IsRowsProducing()) + assert.False(t, (&queryResult{CommandTag: "INSERT 0 1"}).IsRowsProducing()) + assert.True(t, (&queryResult{Columns: []string{"a"}}).IsRowsProducing()) +} diff --git a/experimental/postgres/cmd/targeting.go b/experimental/postgres/cmd/targeting.go new file mode 100644 index 00000000000..e8a17fadfce --- /dev/null +++ b/experimental/postgres/cmd/targeting.go @@ -0,0 +1,173 @@ +package postgrescmd + +import ( + "context" + "errors" + "fmt" + + "github.com/databricks/cli/libs/cmdctx" + "github.com/databricks/cli/libs/lakebase/target" + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/service/postgres" +) + +// resolvedTarget carries everything the query command needs to dial Postgres: +// the endpoint host (resolved through the SDK) and a short-lived OAuth token. +// `kind` records whether we resolved an autoscaling endpoint or a provisioned +// instance, so the caller can pick the right default database name and emit +// kind-appropriate logging. 
+type resolvedTarget struct { + Kind targetKind + Host string + Username string + Token string + // Display strings used only for human-readable logs / errors. + DisplayName string +} + +type targetKind int + +const ( + kindAutoscaling targetKind = iota + kindProvisioned +) + +// targetingFlags is the user-supplied targeting input. Exactly one of: +// - target (full path or instance name) +// - project (with optional branch and endpoint) +// +// must be set. Validated by validateTargeting before any SDK call. +type targetingFlags struct { + target string + project string + branch string + endpoint string +} + +func (f targetingFlags) hasGranular() bool { + return f.project != "" || f.branch != "" || f.endpoint != "" +} + +// validateTargeting enforces "exactly one targeting form" before any SDK call. +// Returns a typed error so the JSON envelope renderer (added in a later PR) +// can surface a structured error. +func validateTargeting(f targetingFlags) error { + switch { + case f.target == "" && !f.hasGranular(): + return errors.New("must specify --target or --project") + case f.target != "" && f.hasGranular(): + return errors.New("--target is mutually exclusive with --project, --branch, --endpoint") + case f.target == "" && f.project == "" && (f.branch != "" || f.endpoint != ""): + return errors.New("--project is required when using --branch or --endpoint") + } + return nil +} + +// resolveTarget translates the validated flags into a resolvedTarget. +// PR 1 supports autoscaling targeting only; provisioned support is added in +// the next PR. A provisioned-shaped --target returns a clear error pointing at +// the experimental status. +func resolveTarget(ctx context.Context, f targetingFlags) (*resolvedTarget, error) { + w := cmdctx.WorkspaceClient(ctx) + + switch { + case f.target != "" && target.IsAutoscalingPath(f.target): + spec, err := target.ParseAutoscalingPath(f.target) + if err != nil { + return nil, err + } + return resolveAutoscaling(ctx, w, spec) + + case f.target != "": + // Provisioned-shaped target. Out of scope for this PR; will be wired in + // the follow-up PR alongside JSON/CSV output. + return nil, errors.New("provisioned instances are not yet supported by this experimental command; use 'databricks psql ' for now") + + default: + spec := target.AutoscalingSpec{ + ProjectID: f.project, + BranchID: f.branch, + EndpointID: f.endpoint, + } + return resolveAutoscaling(ctx, w, spec) + } +} + +// resolveAutoscaling expands a partial spec into a fully-resolved endpoint and +// issues a short-lived OAuth token. Missing branch/endpoint IDs are +// auto-selected when exactly one candidate exists; ambiguity propagates as an +// AmbiguousError with the list of choices. 
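+//
+// Usage sketch (hypothetical project ID):
+//
+//	rt, err := resolveAutoscaling(ctx, w, target.AutoscalingSpec{ProjectID: "foo"})
+//	// err is a *target.AmbiguousError when "foo" has several branches or
+//	// endpoints and no IDs were given to disambiguate.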
+func resolveAutoscaling(ctx context.Context, w *databricks.WorkspaceClient, spec target.AutoscalingSpec) (*resolvedTarget, error) { + if spec.ProjectID == "" { + var err error + spec.ProjectID, err = target.AutoSelectProject(ctx, w) + if err != nil { + return nil, err + } + } + + project, err := target.GetProject(ctx, w, spec.ProjectID) + if err != nil { + return nil, fmt.Errorf("failed to get project: %w", err) + } + + if spec.BranchID == "" { + spec.BranchID, err = target.AutoSelectBranch(ctx, w, project.Name) + if err != nil { + return nil, err + } + } + + if spec.EndpointID == "" { + branchName := project.Name + "/branches/" + spec.BranchID + spec.EndpointID, err = target.AutoSelectEndpoint(ctx, w, branchName) + if err != nil { + return nil, err + } + } + + endpoint, err := target.GetEndpoint(ctx, w, spec.ProjectID, spec.BranchID, spec.EndpointID) + if err != nil { + return nil, fmt.Errorf("failed to get endpoint: %w", err) + } + + if err := checkEndpointReady(endpoint); err != nil { + return nil, err + } + + user, err := w.CurrentUser.Me(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get current user: %w", err) + } + + token, err := target.AutoscalingCredential(ctx, w, endpoint.Name) + if err != nil { + return nil, err + } + + return &resolvedTarget{ + Kind: kindAutoscaling, + Host: endpoint.Status.Hosts.Host, + Username: user.UserName, + Token: token, + DisplayName: endpoint.Name, + }, nil +} + +// checkEndpointReady returns an error if the endpoint is not in a connectable +// state. Idle endpoints are considered connectable (Lakebase wakes them on +// dial); the connect retry loop handles the wake-up window. +func checkEndpointReady(endpoint *postgres.Endpoint) error { + if endpoint.Status == nil { + return errors.New("endpoint status is not available") + } + if endpoint.Status.Hosts == nil || endpoint.Status.Hosts.Host == "" { + return errors.New("endpoint host information is not available") + } + switch endpoint.Status.CurrentState { + case postgres.EndpointStatusStateActive, postgres.EndpointStatusStateIdle: + return nil + default: + return fmt.Errorf("endpoint is not ready for accepting connections (state: %s)", endpoint.Status.CurrentState) + } +} diff --git a/experimental/postgres/cmd/targeting_test.go b/experimental/postgres/cmd/targeting_test.go new file mode 100644 index 00000000000..dfdab0d405c --- /dev/null +++ b/experimental/postgres/cmd/targeting_test.go @@ -0,0 +1,81 @@ +package postgrescmd + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestValidateTargeting(t *testing.T) { + tests := []struct { + name string + flags targetingFlags + wantErr string + }{ + { + name: "neither form", + flags: targetingFlags{}, + wantErr: "must specify --target or --project", + }, + { + name: "only target", + flags: targetingFlags{ + target: "projects/foo", + }, + }, + { + name: "only project", + flags: targetingFlags{ + project: "foo", + }, + }, + { + name: "project and branch", + flags: targetingFlags{ + project: "foo", + branch: "main", + }, + }, + { + name: "project, branch, endpoint", + flags: targetingFlags{ + project: "foo", + branch: "main", + endpoint: "primary", + }, + }, + { + name: "target and project both set", + flags: targetingFlags{ + target: "projects/foo", + project: "foo", + }, + wantErr: "mutually exclusive", + }, + { + name: "branch without project", + flags: targetingFlags{ + branch: "main", + }, + wantErr: "--project is required when using --branch or --endpoint", + }, + { + name: "endpoint without project", + flags: 
targetingFlags{ + endpoint: "primary", + }, + wantErr: "--project is required when using --branch or --endpoint", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := validateTargeting(tc.flags) + if tc.wantErr != "" { + assert.ErrorContains(t, err, tc.wantErr) + return + } + assert.NoError(t, err) + }) + } +} diff --git a/go.mod b/go.mod index f376aa0a98d..170414de39f 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( github.com/hashicorp/terraform-exec v0.25.0 // MPL-2.0 github.com/hashicorp/terraform-json v0.27.2 // MPL-2.0 github.com/hexops/gotextdiff v1.0.3 // BSD-3-Clause + github.com/jackc/pgx/v5 v5.9.1 // MIT github.com/manifoldco/promptui v0.9.0 // BSD-3-Clause github.com/mattn/go-isatty v0.0.20 // MIT github.com/nwidger/jsoncolor v0.3.2 // MIT @@ -80,6 +81,8 @@ require ( github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/go-retryablehttp v0.7.8 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/lucasb-eyer/go-colorful v1.3.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-localereader v0.0.1 // indirect diff --git a/go.sum b/go.sum index f9181b898a2..715807887cd 100644 --- a/go.sum +++ b/go.sum @@ -147,6 +147,14 @@ github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUq github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.9.1 h1:uwrxJXBnx76nyISkhr33kQLlUqjv7et7b9FjCen/tdc= +github.com/jackc/pgx/v5 v5.9.1/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 h1:BQSFePA1RWJOlocH6Fxy8MmwDt+yVQYULKfN0RoTN8A= github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99/go.mod h1:1lJo3i6rXxKeerYnT8Nvf0QmHCRC1n8sfWVwXF2Frvo= github.com/kevinburke/ssh_config v1.2.0 h1:x584FjTGwHzMwvHx18PXxbBVzfnxogHaAReU4gf13a4= @@ -213,7 +221,9 @@ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSS github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= From edbd6bef6bacee89210398b51c3e37a8582ec3fd Mon Sep 17 00:00:00 2001 From: simon Date: Thu, 30 Apr 2026 10:45:56 +0200 Subject: [PATCH 03/25] Address review feedback on PR 1 - Replace exported PathSegmentProjects/ExtractID with focused helpers (ProjectIDFromName/BranchIDFromName/EndpointIDFromName); keeps SDK literals out of call sites. - Type AmbiguousError.Kind as a typed enum (KindProject/Branch/Endpoint/Instance) so producers and the pluralisation switch stay in sync. - Stop setting Choice.DisplayName when it equals the ID; Error() relies on empty-suppression rather than mixed empty/equal-to-ID checks. - Add 57P03 (cannot_connect_now) to the connect-retry allow-list. Postgres emits this during server startup and Lakebase autoscaling can plausibly return it during the wake-up handshake. Tests exercise 57P03/57P01/57014 to lock the boundary. - Require --branch when --endpoint is set. The auto-select-then-look-up flow produces confusing errors when the auto-selected branch does not contain the requested endpoint, and this command is non-interactive so asking the user to be explicit is friendlier. - Reject --max-retries < 1 explicitly instead of silently clamping. Help text already advertised the constraint; matching it at validation time is consistent with the repo's "reject incompatible inputs early" rule. - Harmonise the "endpoint is not ready" error in cmd/psql to include the state, matching the experimental command and giving operators something to act on. - Restore comments removed during the cmd/psql refactor and add a breadcrumb at the GetProvisioned call site about the Name patch. - Add doc comments to AutoSelect* helpers documenting the returned string shape (trailing ID for autoscaling vs full name for provisioned). - Reject trailing components after endpoint in ParseAutoscalingPath; new acceptance test in cmd/psql exercises this. - Drop dead GroupID: "" assignment. 
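
Condensed, the SQLSTATE side of the new boundary reads as follows (a
sketch with a hypothetical helper name; the real logic lives inline in
isRetryableConnectError and additionally unwraps pgconn.ConnectError and
net.OpError):

    func retryableSQLState(code string) bool {
        // connection_exception class, plus cannot_connect_now (57P03).
        return (len(code) == 5 && code[:2] == "08") || code == "57P03"
    }
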
Co-authored-by: Isaac --- .../postgres/query/argument-errors/output.txt | 8 +++ .../postgres/query/argument-errors/script | 6 ++ .../cmd/psql/argument-errors/output.txt | 4 ++ acceptance/cmd/psql/argument-errors/script | 3 + acceptance/cmd/psql/postgres/output.txt | 2 +- cmd/psql/psql.go | 4 +- cmd/psql/psql_autoscaling.go | 2 +- cmd/psql/psql_provisioned.go | 3 + experimental/postgres/cmd/connect.go | 14 ++-- experimental/postgres/cmd/connect_test.go | 30 ++++---- experimental/postgres/cmd/query.go | 12 ++-- experimental/postgres/cmd/targeting.go | 7 ++ experimental/postgres/cmd/targeting_test.go | 8 +++ libs/lakebase/target/autoscaling.go | 51 +++++++------ libs/lakebase/target/provisioned.go | 10 +-- libs/lakebase/target/target.go | 72 ++++++++++++++----- libs/lakebase/target/target_test.go | 39 ++++++---- 17 files changed, 188 insertions(+), 87 deletions(-) diff --git a/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt b/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt index 59ddbfedc6e..c071466a1e3 100644 --- a/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt +++ b/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt @@ -27,6 +27,14 @@ Error: --project is required when using --branch or --endpoint >>> musterr [CLI] experimental postgres query --endpoint primary SELECT 1 Error: --project is required when using --branch or --endpoint +=== Endpoint without branch should error: +>>> musterr [CLI] experimental postgres query --project foo --endpoint primary SELECT 1 +Error: --branch is required when using --endpoint + +=== Max-retries 0 should error: +>>> musterr [CLI] experimental postgres query --project foo --branch main --max-retries 0 SELECT 1 +Error: --max-retries must be at least 1; got 0 + === Provisioned-shaped target should error pointing at psql: >>> musterr [CLI] experimental postgres query --target my-instance SELECT 1 Error: provisioned instances are not yet supported by this experimental command; use 'databricks psql ' for now diff --git a/acceptance/cmd/experimental/postgres/query/argument-errors/script b/acceptance/cmd/experimental/postgres/query/argument-errors/script index 5874c843a03..8d64bf307ed 100644 --- a/acceptance/cmd/experimental/postgres/query/argument-errors/script +++ b/acceptance/cmd/experimental/postgres/query/argument-errors/script @@ -19,6 +19,12 @@ trace musterr $CLI experimental postgres query --branch main "SELECT 1" title "Endpoint without project should error:" trace musterr $CLI experimental postgres query --endpoint primary "SELECT 1" +title "Endpoint without branch should error:" +trace musterr $CLI experimental postgres query --project foo --endpoint primary "SELECT 1" + +title "Max-retries 0 should error:" +trace musterr $CLI experimental postgres query --project foo --branch main --max-retries 0 "SELECT 1" + title "Provisioned-shaped target should error pointing at psql:" trace musterr $CLI experimental postgres query --target my-instance "SELECT 1" diff --git a/acceptance/cmd/psql/argument-errors/output.txt b/acceptance/cmd/psql/argument-errors/output.txt index 35da5961dec..cbf6c093b21 100644 --- a/acceptance/cmd/psql/argument-errors/output.txt +++ b/acceptance/cmd/psql/argument-errors/output.txt @@ -59,6 +59,10 @@ Error: invalid resource path: missing branch ID >>> musterr [CLI] psql projects/my-project/branches/main/endpoints/ Error: invalid resource path: missing endpoint ID +=== Trailing components after endpoint should error: +>>> musterr [CLI] psql 
projects/my-project/branches/main/endpoints/primary/extra +Error: invalid resource path: trailing components after endpoint: projects/my-project/branches/main/endpoints/primary/extra + === Provisioned flag with --project should error: >>> musterr [CLI] psql --provisioned --project foo Error: cannot use --project, --branch, or --endpoint flags with --provisioned diff --git a/acceptance/cmd/psql/argument-errors/script b/acceptance/cmd/psql/argument-errors/script index 7806efb0744..7db1cdbd271 100644 --- a/acceptance/cmd/psql/argument-errors/script +++ b/acceptance/cmd/psql/argument-errors/script @@ -38,6 +38,9 @@ trace musterr $CLI psql projects/my-project/branches/ title "Invalid path with missing endpoint ID should error:" trace musterr $CLI psql projects/my-project/branches/main/endpoints/ +title "Trailing components after endpoint should error:" +trace musterr $CLI psql projects/my-project/branches/main/endpoints/primary/extra + title "Provisioned flag with --project should error:" trace musterr $CLI psql --provisioned --project foo diff --git a/acceptance/cmd/psql/postgres/output.txt b/acceptance/cmd/psql/postgres/output.txt index 5269553a0ce..8df91c6321c 100644 --- a/acceptance/cmd/psql/postgres/output.txt +++ b/acceptance/cmd/psql/postgres/output.txt @@ -50,7 +50,7 @@ PGSSLMODE=require Project: Init Project Branch: main Endpoint: init-ep -Error: endpoint is not ready for accepting connections +Error: endpoint is not ready for accepting connections (state: INIT) === Branch flag without project should fail: >>> musterr [CLI] psql --branch some-branch diff --git a/cmd/psql/psql.go b/cmd/psql/psql.go index e5cfaff5cff..9be7fb5c5df 100644 --- a/cmd/psql/psql.go +++ b/cmd/psql/psql.go @@ -257,7 +257,7 @@ func showSelectionAndConnect(ctx context.Context, retryConfig libpsql.RetryConfi }) } for _, proj := range projects { - displayName := target.ExtractID(proj.Name, target.PathSegmentProjects) + displayName := target.ProjectIDFromName(proj.Name) if proj.Status != nil && proj.Status.DisplayName != "" { displayName = proj.Status.DisplayName } @@ -278,7 +278,7 @@ func showSelectionAndConnect(ctx context.Context, retryConfig libpsql.RetryConfi } if after, ok := strings.CutPrefix(selected, "autoscaling:"); ok { projectName := after - projectID := target.ExtractID(projectName, target.PathSegmentProjects) + projectID := target.ProjectIDFromName(projectName) return connectAutoscaling(ctx, projectID, "", "", retryConfig, extraArgs) } diff --git a/cmd/psql/psql_autoscaling.go b/cmd/psql/psql_autoscaling.go index 4273dad3b50..a4c3293cc18 100644 --- a/cmd/psql/psql_autoscaling.go +++ b/cmd/psql/psql_autoscaling.go @@ -61,7 +61,7 @@ func connectAutoscaling(ctx context.Context, projectID, branchID, endpointID str case postgres.EndpointStatusStateIdle: suffix = " (idle, waking up)" default: - return errors.New("endpoint is not ready for accepting connections") + return fmt.Errorf("endpoint is not ready for accepting connections (state: %s)", state) } cmdio.LogString(ctx, fmt.Sprintf("Connecting to %s endpoint%s...", endpointType, suffix)) diff --git a/cmd/psql/psql_provisioned.go b/cmd/psql/psql_provisioned.go index 9ea88def5ce..c7208906aa8 100644 --- a/cmd/psql/psql_provisioned.go +++ b/cmd/psql/psql_provisioned.go @@ -58,6 +58,7 @@ func connectProvisioned(ctx context.Context, instanceName string, retryConfig li // resolveInstance resolves an instance name to a full instance object. // If instanceName is empty, prompts the user to select one. 
 func resolveInstance(ctx context.Context, w *databricks.WorkspaceClient, instanceName string) (*database.DatabaseInstance, error) {
+	// If instance not specified, select one
 	if instanceName == "" {
 		var err error
 		instanceName, err = selectInstanceID(ctx, w)
@@ -66,6 +67,8 @@ func resolveInstance(ctx context.Context, w *databricks.WorkspaceClient, instanc
 		}
 	}
 
+	// target.GetProvisioned patches Name on the response; the SDK's
+	// GetDatabaseInstance does not always populate it.
 	instance, err := target.GetProvisioned(ctx, w, instanceName)
 	if err != nil {
 		return nil, err
diff --git a/experimental/postgres/cmd/connect.go b/experimental/postgres/cmd/connect.go
index a0674b81ead..920f02e932f 100644
--- a/experimental/postgres/cmd/connect.go
+++ b/experimental/postgres/cmd/connect.go
@@ -70,11 +70,10 @@ func buildPgxConfig(c connectConfig) (*pgx.ConnConfig, error) {
 // indicate the endpoint is asleep or in the middle of a wake-up. Errors that
 // cannot be improved by retrying (auth failures, permission errors,
 // post-query errors) are returned immediately.
+//
+// MaxAttempts must be >= 1 (caller validates). 1 means a single attempt
+// with no retries.
 func connectWithRetry(ctx context.Context, cfg *pgx.ConnConfig, rc retryConfig, dial connectFunc) (*pgx.Conn, error) {
-	if rc.MaxAttempts < 1 {
-		rc.MaxAttempts = 1
-	}
-
 	delay := rc.InitialDelay
 	var lastErr error
@@ -115,6 +114,10 @@ func connectWithRetry(ctx context.Context, cfg *pgx.ConnConfig, rc retryConfig,
 // - pgconn.ConnectError that wraps a retryable network error.
 // - Postgres connection-establishment SQLSTATE codes (08xxx). Lakebase
 //   emits these during cold-start.
+// - Postgres "cannot_connect_now" (57P03), which Postgres returns during
+//   server startup ("the database system is starting up"). Plausibly emitted
+//   during the wake-up handshake window. We do NOT broaden to all of class 57:
+//   57P01 (admin_shutdown)/57P02 (crash_shutdown) are debatable; 57014 is query_canceled.
// // Not retryable: auth errors (28xxx), permission errors (42501), // context cancellation/deadlines, anything after Query has been issued @@ -130,6 +133,9 @@ func isRetryableConnectError(err error) bool { if len(pgErr.Code) == 5 && pgErr.Code[:2] == "08" { return true } + if pgErr.Code == "57P03" { + return true + } return false } diff --git a/experimental/postgres/cmd/connect_test.go b/experimental/postgres/cmd/connect_test.go index 0f7614b1f31..d58fc52cc74 100644 --- a/experimental/postgres/cmd/connect_test.go +++ b/experimental/postgres/cmd/connect_test.go @@ -44,6 +44,21 @@ func TestIsRetryableConnectError(t *testing.T) { err: &pgconn.PgError{Code: "08001", Message: "sqlclient unable to establish sqlconnection"}, want: true, }, + { + name: "57P03 cannot_connect_now", + err: &pgconn.PgError{Code: "57P03", Message: "the database system is starting up"}, + want: true, + }, + { + name: "57P01 admin shutdown not retryable", + err: &pgconn.PgError{Code: "57P01"}, + want: false, + }, + { + name: "57014 query_canceled not retryable", + err: &pgconn.PgError{Code: "57014"}, + want: false, + }, { name: "28000 invalid auth", err: &pgconn.PgError{Code: "28000", Message: "invalid authorization specification"}, @@ -115,21 +130,6 @@ func TestConnectWithRetry_StopsOnNonRetryable(t *testing.T) { assert.Equal(t, 1, calls, "auth errors should not retry") } -func TestConnectWithRetry_ZeroMaxAttemptsTreatedAsOne(t *testing.T) { - ctx := testCtx(t) - calls := 0 - dial := func(ctx context.Context, cfg *pgx.ConnConfig) (*pgx.Conn, error) { - calls++ - return nil, errors.New("nope") - } - cfg := &pgx.ConnConfig{} - rc := retryConfig{MaxAttempts: 0, InitialDelay: time.Millisecond} - - _, err := connectWithRetry(ctx, cfg, rc, dial) - require.Error(t, err) - assert.Equal(t, 1, calls) -} - func TestBuildPgxConfig(t *testing.T) { cfg, err := buildPgxConfig(connectConfig{ Host: "host.example.com", diff --git a/experimental/postgres/cmd/query.go b/experimental/postgres/cmd/query.go index 643aa496e84..fe5cc528ea7 100644 --- a/experimental/postgres/cmd/query.go +++ b/experimental/postgres/cmd/query.go @@ -31,9 +31,8 @@ func newQueryCmd() *cobra.Command { var f queryFlags cmd := &cobra.Command{ - Use: "query [SQL]", - Short: "Run a SQL statement against a Lakebase Postgres endpoint", - GroupID: "", + Use: "query [SQL]", + Short: "Run a SQL statement against a Lakebase Postgres endpoint", Long: `Execute a single SQL statement against a Lakebase Postgres endpoint and render the result as text. 
@@ -72,7 +71,7 @@ Limitations (this release): cmd.Flags().StringVar(&f.endpoint, "endpoint", "", "Autoscaling endpoint ID (default: auto-select if exactly one)") cmd.Flags().StringVarP(&f.database, "database", "d", defaultDatabase, "Database name") cmd.Flags().DurationVar(&f.connectTimeout, "connect-timeout", defaultConnectTimeout, "Connect timeout") - cmd.Flags().IntVar(&f.maxRetries, "max-retries", 3, "Total connect attempts on idle/waking endpoint (1 disables retry)") + cmd.Flags().IntVar(&f.maxRetries, "max-retries", 3, "Total connect attempts on idle/waking endpoint (must be >= 1; 1 disables retry)") cmd.MarkFlagsMutuallyExclusive("target", "project") cmd.MarkFlagsMutuallyExclusive("target", "branch") @@ -89,6 +88,9 @@ func runQuery(ctx context.Context, cmd *cobra.Command, sql string, f queryFlags) if sql == "" { return errors.New("no SQL provided") } + if f.maxRetries < 1 { + return fmt.Errorf("--max-retries must be at least 1; got %d", f.maxRetries) + } if err := validateTargeting(f.targetingFlags); err != nil { return err } @@ -113,7 +115,7 @@ func runQuery(ctx context.Context, cmd *cobra.Command, sql string, f queryFlags) } rc := retryConfig{ - MaxAttempts: max(1, f.maxRetries), + MaxAttempts: f.maxRetries, InitialDelay: time.Second, MaxDelay: 10 * time.Second, } diff --git a/experimental/postgres/cmd/targeting.go b/experimental/postgres/cmd/targeting.go index e8a17fadfce..5e72840f952 100644 --- a/experimental/postgres/cmd/targeting.go +++ b/experimental/postgres/cmd/targeting.go @@ -51,6 +51,11 @@ func (f targetingFlags) hasGranular() bool { // validateTargeting enforces "exactly one targeting form" before any SDK call. // Returns a typed error so the JSON envelope renderer (added in a later PR) // can surface a structured error. +// +// We require --branch when --endpoint is set: this command is non-interactive +// and scriptable, and the alternative (auto-select-then-look-up-endpoint) +// produces confusing errors when the resolved branch does not contain the +// requested endpoint. Asking the user to be explicit is friendlier. 
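+//
+// For example:
+//
+//	--project foo --endpoint primary                 rejected (--branch required)
+//	--project foo --branch main --endpoint primary   accepted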
func validateTargeting(f targetingFlags) error { switch { case f.target == "" && !f.hasGranular(): @@ -59,6 +64,8 @@ func validateTargeting(f targetingFlags) error { return errors.New("--target is mutually exclusive with --project, --branch, --endpoint") case f.target == "" && f.project == "" && (f.branch != "" || f.endpoint != ""): return errors.New("--project is required when using --branch or --endpoint") + case f.endpoint != "" && f.branch == "": + return errors.New("--branch is required when using --endpoint") } return nil } diff --git a/experimental/postgres/cmd/targeting_test.go b/experimental/postgres/cmd/targeting_test.go index dfdab0d405c..62f43d22496 100644 --- a/experimental/postgres/cmd/targeting_test.go +++ b/experimental/postgres/cmd/targeting_test.go @@ -66,6 +66,14 @@ func TestValidateTargeting(t *testing.T) { }, wantErr: "--project is required when using --branch or --endpoint", }, + { + name: "endpoint with project but no branch", + flags: targetingFlags{ + project: "foo", + endpoint: "primary", + }, + wantErr: "--branch is required when using --endpoint", + }, } for _, tc := range tests { diff --git a/libs/lakebase/target/autoscaling.go b/libs/lakebase/target/autoscaling.go index f1edef216d4..3e496611d6b 100644 --- a/libs/lakebase/target/autoscaling.go +++ b/libs/lakebase/target/autoscaling.go @@ -26,20 +26,23 @@ func ListEndpoints(ctx context.Context, w *databricks.WorkspaceClient, branchNam return w.Postgres.ListEndpointsAll(ctx, postgres.ListEndpointsRequest{Parent: branchName}) } -// GetProject fetches a single project by ID. +// GetProject fetches a single project by ID. Unlike GetProvisioned, the +// Postgres autoscaling API populates the Name field on the response so we do +// not need to patch it. func GetProject(ctx context.Context, w *databricks.WorkspaceClient, projectID string) (*postgres.Project, error) { - return w.Postgres.GetProject(ctx, postgres.GetProjectRequest{Name: PathSegmentProjects + "/" + projectID}) + return w.Postgres.GetProject(ctx, postgres.GetProjectRequest{Name: pathSegmentProjects + "/" + projectID}) } -// GetEndpoint fetches a single endpoint by ID, given its parent IDs. +// GetEndpoint fetches a single endpoint by ID, given its parent IDs. Unlike +// GetProvisioned, the Postgres autoscaling API populates the Name field. func GetEndpoint(ctx context.Context, w *databricks.WorkspaceClient, projectID, branchID, endpointID string) (*postgres.Endpoint, error) { name := fmt.Sprintf("projects/%s/branches/%s/endpoints/%s", projectID, branchID, endpointID) return w.Postgres.GetEndpoint(ctx, postgres.GetEndpointRequest{Name: name}) } -// AutoSelectProject returns the only project in the workspace, or an -// AmbiguousError carrying the choices if there are multiple. Returns a plain -// error if there are no projects. +// AutoSelectProject returns the trailing project ID (e.g. "foo", not +// "projects/foo") if exactly one project exists. Returns an *AmbiguousError +// carrying the choices if there are multiple, or a plain error if there are none. 
func AutoSelectProject(ctx context.Context, w *databricks.WorkspaceClient) (string, error) { projects, err := ListProjects(ctx, w) if err != nil { @@ -49,23 +52,24 @@ func AutoSelectProject(ctx context.Context, w *databricks.WorkspaceClient) (stri return "", errors.New("no Lakebase Autoscaling projects found in workspace") } if len(projects) == 1 { - return ExtractID(projects[0].Name, PathSegmentProjects), nil + return extractID(projects[0].Name, pathSegmentProjects), nil } choices := make([]Choice, 0, len(projects)) for _, p := range projects { - id := ExtractID(p.Name, PathSegmentProjects) - display := id - if p.Status != nil && p.Status.DisplayName != "" { + id := extractID(p.Name, pathSegmentProjects) + var display string + if p.Status != nil && p.Status.DisplayName != "" && p.Status.DisplayName != id { display = p.Status.DisplayName } choices = append(choices, Choice{ID: id, DisplayName: display}) } - return "", &AmbiguousError{Kind: "project", FlagHint: "--project", Choices: choices} + return "", &AmbiguousError{Kind: KindProject, FlagHint: "--project", Choices: choices} } -// AutoSelectBranch returns the only branch under projectName, or an -// AmbiguousError if there are multiple. +// AutoSelectBranch returns the trailing branch ID under projectName if +// exactly one branch exists. Returns an *AmbiguousError if there are multiple. +// projectName is the SDK resource name (e.g. "projects/foo"). func AutoSelectBranch(ctx context.Context, w *databricks.WorkspaceClient, projectName string) (string, error) { branches, err := ListBranches(ctx, w, projectName) if err != nil { @@ -75,19 +79,20 @@ func AutoSelectBranch(ctx context.Context, w *databricks.WorkspaceClient, projec return "", errors.New("no branches found in project") } if len(branches) == 1 { - return ExtractID(branches[0].Name, pathSegmentBranches), nil + return extractID(branches[0].Name, pathSegmentBranches), nil } choices := make([]Choice, 0, len(branches)) for _, b := range branches { - id := ExtractID(b.Name, pathSegmentBranches) - choices = append(choices, Choice{ID: id, DisplayName: id}) + id := extractID(b.Name, pathSegmentBranches) + choices = append(choices, Choice{ID: id}) } - return "", &AmbiguousError{Kind: "branch", Parent: projectName, FlagHint: "--branch", Choices: choices} + return "", &AmbiguousError{Kind: KindBranch, Parent: projectName, FlagHint: "--branch", Choices: choices} } -// AutoSelectEndpoint returns the only endpoint under branchName, or an -// AmbiguousError if there are multiple. +// AutoSelectEndpoint returns the trailing endpoint ID under branchName if +// exactly one endpoint exists. Returns an *AmbiguousError if there are multiple. +// branchName is the SDK resource name (e.g. "projects/foo/branches/bar"). 
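+//
+// Usage sketch:
+//
+//	id, err := AutoSelectEndpoint(ctx, w, "projects/foo/branches/main")
+//	// id is the trailing ID (e.g. "primary"); err is an *AmbiguousError
+//	// when the branch has more than one endpoint.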
func AutoSelectEndpoint(ctx context.Context, w *databricks.WorkspaceClient, branchName string) (string, error) { endpoints, err := ListEndpoints(ctx, w, branchName) if err != nil { @@ -97,15 +102,15 @@ func AutoSelectEndpoint(ctx context.Context, w *databricks.WorkspaceClient, bran return "", errors.New("no endpoints found in branch") } if len(endpoints) == 1 { - return ExtractID(endpoints[0].Name, pathSegmentEndpoints), nil + return extractID(endpoints[0].Name, pathSegmentEndpoints), nil } choices := make([]Choice, 0, len(endpoints)) for _, e := range endpoints { - id := ExtractID(e.Name, pathSegmentEndpoints) - choices = append(choices, Choice{ID: id, DisplayName: id}) + id := extractID(e.Name, pathSegmentEndpoints) + choices = append(choices, Choice{ID: id}) } - return "", &AmbiguousError{Kind: "endpoint", Parent: branchName, FlagHint: "--endpoint", Choices: choices} + return "", &AmbiguousError{Kind: KindEndpoint, Parent: branchName, FlagHint: "--endpoint", Choices: choices} } // AutoscalingCredential issues a short-lived OAuth token that can be used to diff --git a/libs/lakebase/target/provisioned.go b/libs/lakebase/target/provisioned.go index 773cc867ce0..261ef37a6a8 100644 --- a/libs/lakebase/target/provisioned.go +++ b/libs/lakebase/target/provisioned.go @@ -29,8 +29,10 @@ func GetProvisioned(ctx context.Context, w *databricks.WorkspaceClient, name str return instance, nil } -// AutoSelectProvisioned returns the only provisioned instance in the workspace, -// or an AmbiguousError if there are multiple. Returns a plain error if none. +// AutoSelectProvisioned returns the only provisioned instance's name (e.g. +// "my-instance"; the database SDK uses flat names, not the "projects/..." +// path shape used by autoscaling). Returns an *AmbiguousError if there are +// multiple, or a plain error if none. func AutoSelectProvisioned(ctx context.Context, w *databricks.WorkspaceClient) (string, error) { instances, err := ListProvisionedInstances(ctx, w) if err != nil { @@ -45,9 +47,9 @@ func AutoSelectProvisioned(ctx context.Context, w *databricks.WorkspaceClient) ( choices := make([]Choice, 0, len(instances)) for _, inst := range instances { - choices = append(choices, Choice{ID: inst.Name, DisplayName: inst.Name}) + choices = append(choices, Choice{ID: inst.Name}) } - return "", &AmbiguousError{Kind: "instance", FlagHint: "--target", Choices: choices} + return "", &AmbiguousError{Kind: KindInstance, FlagHint: "--target", Choices: choices} } // ProvisionedCredential issues a short-lived OAuth token for the provisioned diff --git a/libs/lakebase/target/target.go b/libs/lakebase/target/target.go index d02c95903ce..f0fd2c069e3 100644 --- a/libs/lakebase/target/target.go +++ b/libs/lakebase/target/target.go @@ -11,9 +11,11 @@ import ( ) const ( - // PathSegmentProjects is the leading path segment that identifies an - // autoscaling resource path. Provisioned instance names never start with it. - PathSegmentProjects = "projects" + // pathSegmentProjects is the leading path segment that identifies an + // autoscaling resource path. Provisioned instance names never start with + // it. Use IsAutoscalingPath / ProjectIDFromName from outside this package + // instead of comparing the literal. + pathSegmentProjects = "projects" pathSegmentBranches = "branches" pathSegmentEndpoints = "endpoints" ) @@ -28,11 +30,27 @@ type AutoscalingSpec struct { // Choice is a single candidate returned alongside an AmbiguousError so callers // can either render the list to the user or prompt interactively. 
+// +// DisplayName is the optional friendlier label for the choice. Producers +// should leave it empty when no friendlier label exists; callers that prompt +// interactively can fall back to the ID. type Choice struct { ID string DisplayName string } +// AmbiguousKind is the typed enum for what an AmbiguousError refers to. A +// typed enum (vs raw string) keeps producers and the pluralisation switch in +// AmbiguousError.Error in sync. +type AmbiguousKind string + +const ( + KindProject AmbiguousKind = "project" + KindBranch AmbiguousKind = "branch" + KindEndpoint AmbiguousKind = "endpoint" + KindInstance AmbiguousKind = "instance" +) + // AmbiguousError is returned by AutoSelect* helpers when the SDK returns more // than one candidate and the caller did not specify which one to pick. // @@ -41,26 +59,27 @@ type Choice struct { // scriptable `postgres query` command) propagate it as a plain error: the // formatted message already enumerates the choices. type AmbiguousError struct { - // Kind identifies what was ambiguous: "project", "branch", or "endpoint". - Kind string + Kind AmbiguousKind // Parent is the SDK resource name that contained the ambiguity (e.g. // "projects/foo" when listing branches), or empty when listing projects. Parent string // FlagHint is the flag a user would set to disambiguate (e.g. "--branch"). FlagHint string - // Choices enumerates the candidates returned by the SDK. + // Choices enumerates the candidates returned by the SDK. DisplayName is + // only set when it carries information beyond ID; an empty DisplayName + // suppresses the parenthetical suffix in Error(). Choices []Choice } func (e *AmbiguousError) Error() string { - plural := map[string]string{ - "project": "projects", - "branch": "branches", - "endpoint": "endpoints", - "instance": "instances", + plural := map[AmbiguousKind]string{ + KindProject: "projects", + KindBranch: "branches", + KindEndpoint: "endpoints", + KindInstance: "instances", }[e.Kind] if plural == "" { - plural = e.Kind + plural = string(e.Kind) } var sb strings.Builder @@ -72,7 +91,7 @@ func (e *AmbiguousError) Error() string { for _, c := range e.Choices { sb.WriteString("\n - ") sb.WriteString(c.ID) - if c.DisplayName != "" && c.DisplayName != c.ID { + if c.DisplayName != "" { fmt.Fprintf(&sb, " (%s)", c.DisplayName) } } @@ -90,7 +109,7 @@ func (e *AmbiguousError) Error() string { func ParseAutoscalingPath(input string) (AutoscalingSpec, error) { parts := strings.Split(input, "/") - if len(parts) < 2 || parts[0] != PathSegmentProjects { + if len(parts) < 2 || parts[0] != pathSegmentProjects { return AutoscalingSpec{}, fmt.Errorf("invalid resource path: %s", input) } if parts[1] == "" { @@ -125,10 +144,10 @@ func ParseAutoscalingPath(input string) (AutoscalingSpec, error) { return spec, nil } -// ExtractID returns the value following component in a resource name. -// ExtractID("projects/foo/branches/bar", "branches") returns "bar". +// extractID returns the value following component in a resource name. +// extractID("projects/foo/branches/bar", "branches") returns "bar". // Returns the original name unchanged if component is not found. 
-func ExtractID(name, component string) string { +func extractID(name, component string) string { parts := strings.Split(name, "/") for i := range len(parts) - 1 { if parts[i] == component { @@ -138,8 +157,25 @@ func ExtractID(name, component string) string { return name } +// ProjectIDFromName extracts the project ID from a fully-qualified +// SDK resource name like "projects/foo" or "projects/foo/branches/bar". +// Returns the input unchanged if the name does not contain a "projects/" segment. +func ProjectIDFromName(name string) string { + return extractID(name, pathSegmentProjects) +} + +// BranchIDFromName extracts the branch ID from an SDK resource name. +func BranchIDFromName(name string) string { + return extractID(name, pathSegmentBranches) +} + +// EndpointIDFromName extracts the endpoint ID from an SDK resource name. +func EndpointIDFromName(name string) string { + return extractID(name, pathSegmentEndpoints) +} + // IsAutoscalingPath reports whether s is an autoscaling resource path // (i.e. starts with "projects/"). Provisioned instance names never do. func IsAutoscalingPath(s string) bool { - return strings.HasPrefix(s, PathSegmentProjects+"/") + return strings.HasPrefix(s, pathSegmentProjects+"/") } diff --git a/libs/lakebase/target/target_test.go b/libs/lakebase/target/target_test.go index 4b4a763c122..f502cf6e70c 100644 --- a/libs/lakebase/target/target_test.go +++ b/libs/lakebase/target/target_test.go @@ -64,6 +64,16 @@ func TestParseAutoscalingPath(t *testing.T) { input: "projects/foo/branches/bar/endpoints/baz/extra", wantErr: "trailing components after endpoint", }, + { + name: "empty input", + input: "", + wantErr: "invalid resource path", + }, + { + name: "single slash", + input: "/", + wantErr: "invalid resource path", + }, } for _, tc := range tests { @@ -80,11 +90,12 @@ func TestParseAutoscalingPath(t *testing.T) { } } -func TestExtractID(t *testing.T) { - assert.Equal(t, "bar", ExtractID("projects/foo/branches/bar", "branches")) - assert.Equal(t, "foo", ExtractID("projects/foo", "projects")) - assert.Equal(t, "baz", ExtractID("projects/foo/branches/bar/endpoints/baz", "endpoints")) - assert.Equal(t, "no-component", ExtractID("no-component", "missing")) +func TestIDFromName(t *testing.T) { + assert.Equal(t, "foo", ProjectIDFromName("projects/foo")) + assert.Equal(t, "foo", ProjectIDFromName("projects/foo/branches/bar")) + assert.Equal(t, "bar", BranchIDFromName("projects/foo/branches/bar")) + assert.Equal(t, "bar", BranchIDFromName("projects/foo/branches/bar/endpoints/baz")) + assert.Equal(t, "baz", EndpointIDFromName("projects/foo/branches/bar/endpoints/baz")) } func TestIsAutoscalingPath(t *testing.T) { @@ -96,14 +107,14 @@ func TestIsAutoscalingPath(t *testing.T) { } func TestAmbiguousErrorMessage(t *testing.T) { - t.Run("with parent", func(t *testing.T) { + t.Run("with parent, no display names", func(t *testing.T) { err := &AmbiguousError{ - Kind: "branch", + Kind: KindBranch, Parent: "projects/foo", FlagHint: "--branch", Choices: []Choice{ - {ID: "main", DisplayName: "main"}, - {ID: "feature-x", DisplayName: "feature-x"}, + {ID: "main"}, + {ID: "feature-x"}, }, } assert.Equal(t, @@ -112,13 +123,13 @@ func TestAmbiguousErrorMessage(t *testing.T) { ) }) - t.Run("without parent", func(t *testing.T) { + t.Run("without parent, mixed display names", func(t *testing.T) { err := &AmbiguousError{ - Kind: "project", + Kind: KindProject, FlagHint: "--project", Choices: []Choice{ {ID: "alpha", DisplayName: "Alpha Project"}, - {ID: "beta", DisplayName: "beta"}, + {ID: 
"beta"}, }, } assert.Equal(t, @@ -129,8 +140,8 @@ func TestAmbiguousErrorMessage(t *testing.T) { t.Run("errors.As", func(t *testing.T) { var amb *AmbiguousError - err := error(&AmbiguousError{Kind: "endpoint", FlagHint: "--endpoint"}) + err := error(&AmbiguousError{Kind: KindEndpoint, FlagHint: "--endpoint"}) assert.ErrorAs(t, err, &amb) - assert.Equal(t, "endpoint", amb.Kind) + assert.Equal(t, KindEndpoint, amb.Kind) }) } From 030a2790dd31086bbfc47eaeb21140987eed6246 Mon Sep 17 00:00:00 2001 From: simon Date: Thu, 30 Apr 2026 10:51:59 +0200 Subject: [PATCH 04/25] Address review feedback round 2 - Fix selectAmbiguous: fall back to ID when DisplayName is empty. Round-1 fix to Choice semantics left producers emitting empty DisplayName for branches/endpoints/instances; the psql interactive selector passed that straight to cmdio.Tuple.Name and rendered blank rows. Add the documented fallback. - Drop unused BranchIDFromName / EndpointIDFromName exports; only ProjectIDFromName has callers in this PR. Re-add when first consumed. - Convert chained ifs in isRetryableConnectError to a switch. Co-authored-by: Isaac --- cmd/psql/psql_autoscaling.go | 11 ++++++++++- experimental/postgres/cmd/connect.go | 9 +++++---- libs/lakebase/target/target.go | 10 ---------- libs/lakebase/target/target_test.go | 6 ++---- 4 files changed, 17 insertions(+), 19 deletions(-) diff --git a/cmd/psql/psql_autoscaling.go b/cmd/psql/psql_autoscaling.go index a4c3293cc18..04ccd4bef6b 100644 --- a/cmd/psql/psql_autoscaling.go +++ b/cmd/psql/psql_autoscaling.go @@ -135,10 +135,19 @@ func resolveEndpoint(ctx context.Context, w *databricks.WorkspaceClient, project // AmbiguousError. Caller is expected to have logged a header (e.g. via the // spinner) before invoking. Used to keep psql's interactive UX while letting // the shared lib do the actual list+filter work. +// +// Choice.DisplayName is empty when the producer has no friendlier label than +// the ID (e.g. branches and endpoints, where the ID is the human label). +// The promptui template renders an empty Name as a blank row, so we fall back +// to the ID before handing off to cmdio.SelectOrdered. func selectAmbiguous(ctx context.Context, amb *target.AmbiguousError, prompt string) (string, error) { items := make([]cmdio.Tuple, 0, len(amb.Choices)) for _, c := range amb.Choices { - items = append(items, cmdio.Tuple{Name: c.DisplayName, Id: c.ID}) + name := c.DisplayName + if name == "" { + name = c.ID + } + items = append(items, cmdio.Tuple{Name: name, Id: c.ID}) } return cmdio.SelectOrdered(ctx, items, prompt) } diff --git a/experimental/postgres/cmd/connect.go b/experimental/postgres/cmd/connect.go index 920f02e932f..2eefc681868 100644 --- a/experimental/postgres/cmd/connect.go +++ b/experimental/postgres/cmd/connect.go @@ -129,14 +129,15 @@ func isRetryableConnectError(err error) bool { var pgErr *pgconn.PgError if errors.As(err, &pgErr) { + switch { // 08xxx is the connection_exception class. 
-		if len(pgErr.Code) == 5 && pgErr.Code[:2] == "08" {
+		case len(pgErr.Code) == 5 && pgErr.Code[:2] == "08":
 			return true
-		}
-		if pgErr.Code == "57P03" {
+		case pgErr.Code == "57P03":
 			return true
+		default:
+			return false
 		}
-		return false
 	}
 
 	var connectErr *pgconn.ConnectError
diff --git a/libs/lakebase/target/target.go b/libs/lakebase/target/target.go
index f0fd2c069e3..1874829acce 100644
--- a/libs/lakebase/target/target.go
+++ b/libs/lakebase/target/target.go
@@ -164,16 +164,6 @@ func ProjectIDFromName(name string) string {
 	return extractID(name, pathSegmentProjects)
 }
 
-// BranchIDFromName extracts the branch ID from an SDK resource name.
-func BranchIDFromName(name string) string {
-	return extractID(name, pathSegmentBranches)
-}
-
-// EndpointIDFromName extracts the endpoint ID from an SDK resource name.
-func EndpointIDFromName(name string) string {
-	return extractID(name, pathSegmentEndpoints)
-}
-
 // IsAutoscalingPath reports whether s is an autoscaling resource path
 // (i.e. starts with "projects/"). Provisioned instance names never do.
 func IsAutoscalingPath(s string) bool {
diff --git a/libs/lakebase/target/target_test.go b/libs/lakebase/target/target_test.go
index f502cf6e70c..f1726890330 100644
--- a/libs/lakebase/target/target_test.go
+++ b/libs/lakebase/target/target_test.go
@@ -90,12 +90,10 @@ func TestParseAutoscalingPath(t *testing.T) {
 	}
 }
 
-func TestIDFromName(t *testing.T) {
+func TestProjectIDFromName(t *testing.T) {
 	assert.Equal(t, "foo", ProjectIDFromName("projects/foo"))
 	assert.Equal(t, "foo", ProjectIDFromName("projects/foo/branches/bar"))
-	assert.Equal(t, "bar", BranchIDFromName("projects/foo/branches/bar"))
-	assert.Equal(t, "bar", BranchIDFromName("projects/foo/branches/bar/endpoints/baz"))
-	assert.Equal(t, "baz", EndpointIDFromName("projects/foo/branches/bar/endpoints/baz"))
+	assert.Equal(t, "no-projects", ProjectIDFromName("no-projects"))
 }
 
 func TestIsAutoscalingPath(t *testing.T) {

From 5e0e3dd2157ea324c819795b80c8ebf0dd79cd4d Mon Sep 17 00:00:00 2001
From: simon
Date: Thu, 30 Apr 2026 11:01:07 +0200
Subject: [PATCH 05/25] Provisioned targeting + JSON/CSV streaming + typed values

This is PR 2 of the experimental postgres query stack. Builds on PR 1's
scaffold to fill in the rest of the single-input output story.

Provisioned support: --target accepts both autoscaling resource paths
(starts with "projects/") and provisioned instance names (everything
else). Granular --project/--branch/--endpoint targeting stays
autoscaling-only. resolveProvisioned validates the instance is in the
AVAILABLE state and has read/write DNS before issuing a token.

Output renderers are now sinks fed by executeOne row-by-row instead of
buffering. textSink keeps buffering (tabwriter needs the widest cell to
align); jsonSink and csvSink stream. jsonSink uses
separator-before-element writing throughout so a mid-stream error can
close the array cleanly via OnError, leaving stdout as parseable JSON
with a partial result.

JSON value rendering follows the typed mapping: numbers stay numeric
inside ±2^53, become strings outside; NaN/Inf become
"NaN"/"Infinity"/"-Infinity"; timestamps render in RFC3339; jsonb passes
through as json.RawMessage so e.g. {"id": 9007199254740993} keeps its
digits; bytea base64-encodes; everything else falls back to canonical
Postgres text.

CSV and text use Postgres' canonical text representation, with NULL
rendered as the literal "NULL" in text and as empty in CSV (matches
psql --csv).
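A minimal standalone sketch of that separator-before-element shape, in
isolation (streamArray and its row-source callback are illustrative
names, not the sink API this patch adds; the real sinks feed from pgx
rows):

    package main

    import (
        "encoding/json"
        "errors"
        "fmt"
        "io"
        "os"
    )

    // streamArray writes rows as a JSON array: "[\n" up front, ",\n" before
    // every element except the first, "\n]\n" at the end. Because the
    // separator precedes each element, a mid-stream failure only needs the
    // closing "\n]\n" to leave the output as parseable JSON with a partial
    // result.
    func streamArray(out io.Writer, next func() (map[string]any, error)) error {
        if _, err := io.WriteString(out, "[\n"); err != nil {
            return err
        }
        written := 0
        for {
            row, err := next()
            if errors.Is(err, io.EOF) {
                break
            }
            if err != nil {
                // Close the array before surfacing the error (the OnError role).
                _, _ = io.WriteString(out, "\n]\n")
                return err
            }
            if written > 0 {
                if _, err := io.WriteString(out, ",\n"); err != nil {
                    return err
                }
            }
            b, err := json.Marshal(row)
            if err != nil {
                _, _ = io.WriteString(out, "\n]\n")
                return err
            }
            if _, err := out.Write(b); err != nil {
                return err
            }
            written++
        }
        _, err := io.WriteString(out, "\n]\n")
        return err
    }

    func main() {
        rows := []map[string]any{{"id": int64(1)}, {"id": int64(2)}}
        i := 0
        err := streamArray(os.Stdout, func() (map[string]any, error) {
            if i == len(rows) {
                return nil, io.EOF
            }
            r := rows[i]
            i++
            return r, nil
        })
        if err != nil {
            fmt.Fprintln(os.Stderr, "error:", err)
        }
    }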
Output mode auto-selection mirrors aitools query: --output text on a non-TTY stdout falls back to JSON. DATABRICKS_OUTPUT_FORMAT is honoured when --output is not explicitly set; invalid env values are silently ignored. Duplicate column names are deterministically renamed (id, id__2, id__3) with a stderr warning. Acceptance: argument-errors loses the now-obsolete "provisioned not yet supported" case; new provisioned-targeting test exercises not-AVAILABLE / no-DNS / 404 paths via the SDK testserver mock. Co-authored-by: Isaac --- .../postgres/query/argument-errors/output.txt | 4 - .../postgres/query/argument-errors/script | 3 - .../query/provisioned-targeting/out.test.toml | 8 + .../query/provisioned-targeting/output.txt | 12 ++ .../query/provisioned-targeting/script | 8 + .../query/provisioned-targeting/test.toml | 30 +++ experimental/postgres/cmd/execute.go | 67 ++++--- experimental/postgres/cmd/output.go | 71 +++++++ experimental/postgres/cmd/output_test.go | 79 ++++++++ experimental/postgres/cmd/query.go | 70 +++++-- experimental/postgres/cmd/render.go | 77 ++++---- experimental/postgres/cmd/render_csv.go | 80 ++++++++ experimental/postgres/cmd/render_csv_test.go | 49 +++++ experimental/postgres/cmd/render_json.go | 173 ++++++++++++++++++ experimental/postgres/cmd/render_json_test.go | 118 ++++++++++++ experimental/postgres/cmd/render_test.go | 68 +++---- experimental/postgres/cmd/targeting.go | 47 ++++- experimental/postgres/cmd/value.go | 152 +++++++++++++++ experimental/postgres/cmd/value_test.go | 84 +++++++++ 19 files changed, 1079 insertions(+), 121 deletions(-) create mode 100644 acceptance/cmd/experimental/postgres/query/provisioned-targeting/out.test.toml create mode 100644 acceptance/cmd/experimental/postgres/query/provisioned-targeting/output.txt create mode 100644 acceptance/cmd/experimental/postgres/query/provisioned-targeting/script create mode 100644 acceptance/cmd/experimental/postgres/query/provisioned-targeting/test.toml create mode 100644 experimental/postgres/cmd/output.go create mode 100644 experimental/postgres/cmd/output_test.go create mode 100644 experimental/postgres/cmd/render_csv.go create mode 100644 experimental/postgres/cmd/render_csv_test.go create mode 100644 experimental/postgres/cmd/render_json.go create mode 100644 experimental/postgres/cmd/render_json_test.go create mode 100644 experimental/postgres/cmd/value.go create mode 100644 experimental/postgres/cmd/value_test.go diff --git a/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt b/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt index c071466a1e3..238e099299c 100644 --- a/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt +++ b/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt @@ -35,10 +35,6 @@ Error: --branch is required when using --endpoint >>> musterr [CLI] experimental postgres query --project foo --branch main --max-retries 0 SELECT 1 Error: --max-retries must be at least 1; got 0 -=== Provisioned-shaped target should error pointing at psql: ->>> musterr [CLI] experimental postgres query --target my-instance SELECT 1 -Error: provisioned instances are not yet supported by this experimental command; use 'databricks psql ' for now - === Malformed autoscaling path should error: >>> musterr [CLI] experimental postgres query --target projects/ SELECT 1 Error: invalid resource path: missing project ID diff --git a/acceptance/cmd/experimental/postgres/query/argument-errors/script 
b/acceptance/cmd/experimental/postgres/query/argument-errors/script index 8d64bf307ed..ac6ac42746e 100644 --- a/acceptance/cmd/experimental/postgres/query/argument-errors/script +++ b/acceptance/cmd/experimental/postgres/query/argument-errors/script @@ -25,9 +25,6 @@ trace musterr $CLI experimental postgres query --project foo --endpoint primary title "Max-retries 0 should error:" trace musterr $CLI experimental postgres query --project foo --branch main --max-retries 0 "SELECT 1" -title "Provisioned-shaped target should error pointing at psql:" -trace musterr $CLI experimental postgres query --target my-instance "SELECT 1" - title "Malformed autoscaling path should error:" trace musterr $CLI experimental postgres query --target projects/ "SELECT 1" diff --git a/acceptance/cmd/experimental/postgres/query/provisioned-targeting/out.test.toml b/acceptance/cmd/experimental/postgres/query/provisioned-targeting/out.test.toml new file mode 100644 index 00000000000..40bb0d10471 --- /dev/null +++ b/acceptance/cmd/experimental/postgres/query/provisioned-targeting/out.test.toml @@ -0,0 +1,8 @@ +Local = true +Cloud = false + +[GOOS] + windows = false + +[EnvMatrix] + DATABRICKS_BUNDLE_ENGINE = ["terraform", "direct"] diff --git a/acceptance/cmd/experimental/postgres/query/provisioned-targeting/output.txt b/acceptance/cmd/experimental/postgres/query/provisioned-targeting/output.txt new file mode 100644 index 00000000000..0f00f8b3e44 --- /dev/null +++ b/acceptance/cmd/experimental/postgres/query/provisioned-targeting/output.txt @@ -0,0 +1,12 @@ + +=== Provisioned target in non-AVAILABLE state should error: +>>> musterr [CLI] experimental postgres query --target starting-instance SELECT 1 +Error: database instance "starting-instance" is not ready for accepting connections (state: STARTING) + +=== Provisioned target with no DNS should error: +>>> musterr [CLI] experimental postgres query --target no-dns-instance SELECT 1 +Error: database instance "no-dns-instance" has no read/write DNS yet + +=== Provisioned target not found should surface SDK 404: +>>> musterr [CLI] experimental postgres query --target missing-instance SELECT 1 +Error: failed to get database instance: instance not found diff --git a/acceptance/cmd/experimental/postgres/query/provisioned-targeting/script b/acceptance/cmd/experimental/postgres/query/provisioned-targeting/script new file mode 100644 index 00000000000..d8995c62a6c --- /dev/null +++ b/acceptance/cmd/experimental/postgres/query/provisioned-targeting/script @@ -0,0 +1,8 @@ +title "Provisioned target in non-AVAILABLE state should error:" +trace musterr $CLI experimental postgres query --target starting-instance "SELECT 1" + +title "Provisioned target with no DNS should error:" +trace musterr $CLI experimental postgres query --target no-dns-instance "SELECT 1" + +title "Provisioned target not found should surface SDK 404:" +trace musterr $CLI experimental postgres query --target missing-instance "SELECT 1" diff --git a/acceptance/cmd/experimental/postgres/query/provisioned-targeting/test.toml b/acceptance/cmd/experimental/postgres/query/provisioned-targeting/test.toml new file mode 100644 index 00000000000..4821dab5741 --- /dev/null +++ b/acceptance/cmd/experimental/postgres/query/provisioned-targeting/test.toml @@ -0,0 +1,30 @@ +GOOS.windows = false + +[[Server]] +Pattern = "GET /api/2.0/database/instances/starting-instance" +Response.Body = ''' +{ + "name": "starting-instance", + "state": "STARTING", + "read_write_dns": "starting.example.com" +} +''' + +[[Server]] +Pattern = 
"GET /api/2.0/database/instances/no-dns-instance" +Response.Body = ''' +{ + "name": "no-dns-instance", + "state": "AVAILABLE" +} +''' + +[[Server]] +Pattern = "GET /api/2.0/database/instances/missing-instance" +Response.StatusCode = 404 +Response.Body = ''' +{ + "error_code": "NOT_FOUND", + "message": "instance not found" +} +''' diff --git a/experimental/postgres/cmd/execute.go b/experimental/postgres/cmd/execute.go index c29f7ce59d6..61d93bd7bc2 100644 --- a/experimental/postgres/cmd/execute.go +++ b/experimental/postgres/cmd/execute.go @@ -5,10 +5,29 @@ import ( "fmt" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" ) -// executeOne runs a single SQL statement against an open connection and -// captures the result in a queryResult. +// rowSink consumes a query result one row at a time. Sinks that maintain open +// output structures (e.g. a streaming JSON array) implement OnError so they +// can close cleanly when the iteration terminates with a partial result. +type rowSink interface { + // Begin is called once with the column descriptions before any Row. + // For command-only statements (no rows), Begin is still called with an + // empty slice so the sink can lock in its rows-vs-command shape. + Begin(fields []pgconn.FieldDescription) error + // Row delivers one decoded row. Values aligns with the fields passed to + // Begin and uses pgx's Go type mapping (int64, float64, time.Time, ...). + Row(values []any) error + // End is called once on successful completion. + End(commandTag string) error + // OnError is called if iteration errors after Begin returned. The sink + // is expected to flush any in-progress output structures so stdout + // remains well-formed. The caller still surfaces err to its caller. + OnError(err error) +} + +// executeOne runs a single SQL statement and streams the result through sink. // // We pass QueryExecModeExec explicitly (not the pgx default // QueryExecModeCacheStatement) for two reasons: @@ -17,46 +36,38 @@ import ( // closed at the end of the command, so the cached prepared statement // never gets reused. // 2. Exec mode uses Postgres' extended-protocol "exec" path with text-format -// result columns. That makes rendering canonical-Postgres-text output -// (PR 1) and CSV (later PR) straightforward; the cache mode defaults to -// binary and we'd be reformatting back to text. +// result columns, which keeps the canonical-Postgres-text rendering for +// --output text and --output csv straightforward. // // QueryExecModeExec still uses extended protocol with a single statement and // no implicit transaction wrap, so transaction-disallowed DDL like -// `CREATE DATABASE` works. -func executeOne(ctx context.Context, conn *pgx.Conn, sql string) (*queryResult, error) { +// CREATE DATABASE works. 
+func executeOne(ctx context.Context, conn *pgx.Conn, sql string, sink rowSink) error { rows, err := conn.Query(ctx, sql, pgx.QueryExecModeExec) if err != nil { - return nil, fmt.Errorf("query failed: %w", err) + return fmt.Errorf("query failed: %w", err) } defer rows.Close() - result := &queryResult{SQL: sql} - - fields := rows.FieldDescriptions() - if len(fields) > 0 { - result.Columns = make([]string, len(fields)) - for i, f := range fields { - result.Columns[i] = f.Name - } + if err := sink.Begin(rows.FieldDescriptions()); err != nil { + return err } for rows.Next() { - raw := rows.RawValues() - row := make([]string, len(raw)) - for i, b := range raw { - if b == nil { - row[i] = "NULL" - continue - } - row[i] = string(b) + values, err := rows.Values() + if err != nil { + sink.OnError(err) + return fmt.Errorf("decode row: %w", err) + } + if err := sink.Row(values); err != nil { + sink.OnError(err) + return err } - result.Rows = append(result.Rows, row) } if err := rows.Err(); err != nil { - return nil, fmt.Errorf("query failed: %w", err) + sink.OnError(err) + return fmt.Errorf("query failed: %w", err) } - result.CommandTag = rows.CommandTag().String() - return result, nil + return sink.End(rows.CommandTag().String()) } diff --git a/experimental/postgres/cmd/output.go b/experimental/postgres/cmd/output.go new file mode 100644 index 00000000000..c293b424b73 --- /dev/null +++ b/experimental/postgres/cmd/output.go @@ -0,0 +1,71 @@ +package postgrescmd + +import ( + "context" + "fmt" + "strings" + + "github.com/databricks/cli/libs/env" +) + +// outputFormat is the user-selectable output shape. Using a string typedef +// instead of an int enum keeps the help text and DATABRICKS_OUTPUT_FORMAT env +// var values self-describing. +type outputFormat string + +const ( + outputText outputFormat = "text" + outputJSON outputFormat = "json" + outputCSV outputFormat = "csv" + + // envOutputFormat matches the env var name in cmd/root/io.go. Reading it + // here lets pipelines set DATABRICKS_OUTPUT_FORMAT once for all + // commands. See aitools query for a parallel pattern. + envOutputFormat = "DATABRICKS_OUTPUT_FORMAT" +) + +// allOutputFormats is the canonical order shown in completions / help. +var allOutputFormats = []outputFormat{outputText, outputJSON, outputCSV} + +// resolveOutputFormat picks the effective output format. Precedence: +// +// 1. The local --output flag if it was explicitly set. +// 2. DATABRICKS_OUTPUT_FORMAT env var if set to a known value (invalid +// values are silently ignored, matching cmd/root/io.go and aitools). +// 3. The flag default ("text"). +// +// Then the auto-selection rule applies: text on a non-TTY stdout falls back +// to JSON. This matches the aitools query command and means scripts piping +// stdout get machine-readable output by default. +// +// flagSet is true if the user explicitly passed --output. stdoutTTY is true +// if stdout is a terminal. 
+func resolveOutputFormat(ctx context.Context, flagValue string, flagSet, stdoutTTY bool) (outputFormat, error) { + chosen := outputFormat(strings.ToLower(flagValue)) + + if !flagSet { + if v, ok := env.Lookup(ctx, envOutputFormat); ok { + candidate := outputFormat(strings.ToLower(v)) + if isKnownOutputFormat(candidate) { + chosen = candidate + } + } + } + + if !isKnownOutputFormat(chosen) { + return "", fmt.Errorf("unsupported output format %q; expected one of: text, json, csv", flagValue) + } + + if chosen == outputText && !stdoutTTY { + return outputJSON, nil + } + return chosen, nil +} + +func isKnownOutputFormat(f outputFormat) bool { + switch f { + case outputText, outputJSON, outputCSV: + return true + } + return false +} diff --git a/experimental/postgres/cmd/output_test.go b/experimental/postgres/cmd/output_test.go new file mode 100644 index 00000000000..79289a43e56 --- /dev/null +++ b/experimental/postgres/cmd/output_test.go @@ -0,0 +1,79 @@ +package postgrescmd + +import ( + "testing" + + "github.com/databricks/cli/libs/env" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestResolveOutputFormat_Defaults(t *testing.T) { + ctx := t.Context() + + got, err := resolveOutputFormat(ctx, "text", false, true) + require.NoError(t, err) + assert.Equal(t, outputText, got) +} + +func TestResolveOutputFormat_TextOnPipeFallsBackToJSON(t *testing.T) { + ctx := t.Context() + got, err := resolveOutputFormat(ctx, "text", false, false) + require.NoError(t, err) + assert.Equal(t, outputJSON, got) +} + +func TestResolveOutputFormat_ExplicitTextOnPipeAlsoFallsBackToJSON(t *testing.T) { + ctx := t.Context() + got, err := resolveOutputFormat(ctx, "text", true, false) + require.NoError(t, err) + assert.Equal(t, outputJSON, got) +} + +func TestResolveOutputFormat_ExplicitJSON(t *testing.T) { + ctx := t.Context() + got, err := resolveOutputFormat(ctx, "json", true, true) + require.NoError(t, err) + assert.Equal(t, outputJSON, got) +} + +func TestResolveOutputFormat_ExplicitCSV(t *testing.T) { + ctx := t.Context() + got, err := resolveOutputFormat(ctx, "csv", true, true) + require.NoError(t, err) + assert.Equal(t, outputCSV, got) +} + +func TestResolveOutputFormat_EnvVarHonoredWhenFlagNotSet(t *testing.T) { + ctx := env.Set(t.Context(), envOutputFormat, "csv") + got, err := resolveOutputFormat(ctx, "text", false, true) + require.NoError(t, err) + assert.Equal(t, outputCSV, got) +} + +func TestResolveOutputFormat_FlagOverridesEnvVar(t *testing.T) { + ctx := env.Set(t.Context(), envOutputFormat, "csv") + got, err := resolveOutputFormat(ctx, "json", true, true) + require.NoError(t, err) + assert.Equal(t, outputJSON, got) +} + +func TestResolveOutputFormat_InvalidEnvVarIgnored(t *testing.T) { + ctx := env.Set(t.Context(), envOutputFormat, "yaml") + got, err := resolveOutputFormat(ctx, "text", false, true) + require.NoError(t, err) + assert.Equal(t, outputText, got) +} + +func TestResolveOutputFormat_InvalidFlagErrors(t *testing.T) { + ctx := t.Context() + _, err := resolveOutputFormat(ctx, "yaml", true, true) + assert.ErrorContains(t, err, "unsupported output format") +} + +func TestResolveOutputFormat_CaseInsensitive(t *testing.T) { + ctx := t.Context() + got, err := resolveOutputFormat(ctx, "JSON", true, true) + require.NoError(t, err) + assert.Equal(t, outputJSON, got) +} diff --git a/experimental/postgres/cmd/query.go b/experimental/postgres/cmd/query.go index fe5cc528ea7..c3078f24d82 100644 --- a/experimental/postgres/cmd/query.go +++ 
b/experimental/postgres/cmd/query.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "io" "strings" "time" @@ -25,6 +26,11 @@ type queryFlags struct { database string connectTimeout time.Duration maxRetries int + + // outputFormat is the raw flag value. resolveOutputFormat turns it into + // the effective format (which may differ when stdout is piped). + outputFormat string + outputFormatSet bool } func newQueryCmd() *cobra.Command { @@ -33,15 +39,29 @@ func newQueryCmd() *cobra.Command { cmd := &cobra.Command{ Use: "query [SQL]", Short: "Run a SQL statement against a Lakebase Postgres endpoint", - Long: `Execute a single SQL statement against a Lakebase Postgres endpoint and -render the result as text. + Long: `Execute a single SQL statement against a Lakebase Postgres endpoint. Targeting (exactly one form required): - --target STRING Autoscaling resource path - (e.g. projects/foo/branches/main/endpoints/primary) + --target STRING Provisioned instance name OR autoscaling resource path + (e.g. my-instance, projects/foo/branches/main/endpoints/primary) --project ID Autoscaling project ID --branch ID Autoscaling branch ID (default: auto-select if exactly one) - --endpoint ID Autoscaling endpoint ID (default: auto-select if exactly one) + --endpoint ID Autoscaling endpoint ID + +Output: + --output text Aligned table for rows-producing statements (default). + Falls back to JSON when stdout is not a terminal so + scripts piping the output get machine-readable results. + --output json Top-level array of row objects, streamed for + rows-producing statements. Command-only statements + emit a single {"command": "...", "rows_affected": N} + object. Numbers, booleans, NULL, jsonb, timestamps + render with their JSON-native types. + --output csv Header row + one CSV row per result row, streamed. + Command-only statements write the command tag to + stderr. + +DATABRICKS_OUTPUT_FORMAT is honoured when --output is not explicitly set. This is an experimental command. The flag set, output shape, and supported target kinds will expand in subsequent releases. @@ -49,10 +69,6 @@ target kinds will expand in subsequent releases. Limitations (this release): - Single SQL statement per invocation (multi-statement support comes later). - - Text output only. JSON and CSV output come in a follow-up release. - - Only Lakebase Autoscaling endpoints are supported. Provisioned instance - support comes in a follow-up release; use 'databricks psql ' as a - workaround for now. - No interactive REPL. 'databricks psql' continues to own that surface. - Multi-statement strings (e.g. "SELECT 1; SELECT 2") are not supported. - The OAuth token is generated once per invocation and is valid for 1h. @@ -61,17 +77,26 @@ Limitations (this release): Args: cobra.ExactArgs(1), PreRunE: root.MustWorkspaceClient, RunE: func(cmd *cobra.Command, args []string) error { + f.outputFormatSet = cmd.Flag("output").Changed return runQuery(cmd.Context(), cmd, args[0], f) }, } - cmd.Flags().StringVar(&f.target, "target", "", "Autoscaling resource path (e.g. 
projects/foo/branches/main/endpoints/primary)") + cmd.Flags().StringVar(&f.target, "target", "", "Provisioned instance name OR autoscaling resource path") cmd.Flags().StringVar(&f.project, "project", "", "Autoscaling project ID") cmd.Flags().StringVar(&f.branch, "branch", "", "Autoscaling branch ID (default: auto-select if exactly one)") cmd.Flags().StringVar(&f.endpoint, "endpoint", "", "Autoscaling endpoint ID (default: auto-select if exactly one)") cmd.Flags().StringVarP(&f.database, "database", "d", defaultDatabase, "Database name") cmd.Flags().DurationVar(&f.connectTimeout, "connect-timeout", defaultConnectTimeout, "Connect timeout") cmd.Flags().IntVar(&f.maxRetries, "max-retries", 3, "Total connect attempts on idle/waking endpoint (must be >= 1; 1 disables retry)") + cmd.Flags().StringVarP(&f.outputFormat, "output", "o", string(outputText), "Output format: text, json, or csv") + cmd.RegisterFlagCompletionFunc("output", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) { + out := make([]string, len(allOutputFormats)) + for i, f := range allOutputFormats { + out[i] = string(f) + } + return out, cobra.ShellCompDirectiveNoFileComp + }) cmd.MarkFlagsMutuallyExclusive("target", "project") cmd.MarkFlagsMutuallyExclusive("target", "branch") @@ -95,6 +120,12 @@ func runQuery(ctx context.Context, cmd *cobra.Command, sql string, f queryFlags) return err } + stdoutTTY := cmdio.SupportsColor(ctx, cmd.OutOrStdout()) + format, err := resolveOutputFormat(ctx, f.outputFormat, f.outputFormatSet, stdoutTTY) + if err != nil { + return err + } + resolved, err := resolveTarget(ctx, f.targetingFlags) if err != nil { return err @@ -126,10 +157,19 @@ func runQuery(ctx context.Context, cmd *cobra.Command, sql string, f queryFlags) } defer conn.Close(context.WithoutCancel(ctx)) - result, err := executeOne(ctx, conn, sql) - if err != nil { - return err - } + sink := newSink(format, cmd.OutOrStdout(), cmd.ErrOrStderr()) + return executeOne(ctx, conn, sql, sink) +} - return renderText(cmd.OutOrStdout(), result) +// newSink returns the rowSink for the chosen output format. Kept separate +// from runQuery so tests can build sinks without going through pgx. +func newSink(format outputFormat, out, stderr io.Writer) rowSink { + switch format { + case outputJSON: + return newJSONSink(out, stderr) + case outputCSV: + return newCSVSink(out, stderr) + default: + return newTextSink(out) + } } diff --git a/experimental/postgres/cmd/render.go b/experimental/postgres/cmd/render.go index ff923c4a92e..bc45c89e0d0 100644 --- a/experimental/postgres/cmd/render.go +++ b/experimental/postgres/cmd/render.go @@ -5,59 +5,68 @@ import ( "io" "strings" "text/tabwriter" + + "github.com/jackc/pgx/v5/pgconn" ) -// queryResult is the rendered shape of a single SQL execution. PR 1 only -// renders text; later PRs add JSON and CSV against the same struct. +// textSink renders results as plain text: a tabwriter-aligned table for +// rows-producing statements, the command tag for command-only ones. // -// columns is empty for command-only statements (INSERT, CREATE DATABASE, ...); -// rows is empty when no rows were returned (or for command-only statements). -type queryResult struct { - SQL string - // CommandTag is the Postgres command tag for the statement (e.g. "INSERT 0 5", - // "CREATE DATABASE"). Always set; used for command-only statements and as a - // trailer for rows-producing ones. 
-	CommandTag string
-	Columns    []string
-	Rows       [][]string
+// Text output buffers all rows because tabwriter needs the widest cell in each
+// column before it can align. Streaming output is provided by the JSON and CSV
+// sinks; users with huge result sets should pick those.
+type textSink struct {
+	out     io.Writer
+	columns []string
+	rows    [][]string
 }
 
-// IsRowsProducing reports whether the statement returned a row description.
-// Determined at runtime via FieldDescriptions() rather than by parsing the
-// leading SQL keyword: `INSERT ... RETURNING` and CTEs ending in a SELECT are
-// rows-producing despite their leading keywords.
-func (r *queryResult) IsRowsProducing() bool {
-	return len(r.Columns) > 0
+func newTextSink(out io.Writer) *textSink {
+	return &textSink{out: out}
 }
 
-// renderText writes a result in plain text.
-//
-// For rows-producing statements we use a tabwriter-aligned table followed by
-// a `(N rows)` footer, mimicking psql's compact text shape. For command-only
-// statements we just print the command tag.
-//
-// PR 1 always uses the static (buffered) shape. The interactive table viewer
-// for >30 rows lands in a later PR alongside the multi-input output shapes.
-func renderText(out io.Writer, r *queryResult) error {
-	if !r.IsRowsProducing() {
-		_, err := fmt.Fprintln(out, r.CommandTag)
+func (s *textSink) Begin(fields []pgconn.FieldDescription) error {
+	s.columns = make([]string, len(fields))
+	for i, f := range fields {
+		s.columns[i] = f.Name
+	}
+	return nil
+}
+
+func (s *textSink) Row(values []any) error {
+	row := make([]string, len(values))
+	for i, v := range values {
+		row[i] = textValue(v)
+	}
+	s.rows = append(s.rows, row)
+	return nil
+}
+
+func (s *textSink) End(commandTag string) error {
+	if len(s.columns) == 0 {
+		_, err := fmt.Fprintln(s.out, commandTag)
 		return err
 	}
 
-	tw := tabwriter.NewWriter(out, 0, 0, 2, ' ', 0)
-	fmt.Fprintln(tw, strings.Join(r.Columns, "\t"))
-	fmt.Fprintln(tw, strings.Join(headerSeparator(r.Columns), "\t"))
-	for _, row := range r.Rows {
+	tw := tabwriter.NewWriter(s.out, 0, 0, 2, ' ', 0)
+	fmt.Fprintln(tw, strings.Join(s.columns, "\t"))
+	fmt.Fprintln(tw, strings.Join(headerSeparator(s.columns), "\t"))
+	for _, row := range s.rows {
 		fmt.Fprintln(tw, strings.Join(row, "\t"))
 	}
 	if err := tw.Flush(); err != nil {
 		return err
 	}
 
-	_, err := fmt.Fprintf(out, "(%d %s)\n", len(r.Rows), pluralize(len(r.Rows), "row", "rows"))
+	_, err := fmt.Fprintf(s.out, "(%d %s)\n", len(s.rows), pluralize(len(s.rows), "row", "rows"))
 	return err
 }
 
+// OnError for text sinks is a no-op: text output prints whatever rows have
+// already been collected, with no open structure to close. The caller
+// surfaces the error separately (cobra's default error rendering).
+func (s *textSink) OnError(err error) {}
+
 func headerSeparator(cols []string) []string {
 	out := make([]string, len(cols))
 	for i, c := range cols {
diff --git a/experimental/postgres/cmd/render_csv.go b/experimental/postgres/cmd/render_csv.go
new file mode 100644
index 00000000000..940e11324f5
--- /dev/null
+++ b/experimental/postgres/cmd/render_csv.go
@@ -0,0 +1,80 @@
+package postgrescmd
+
+import (
+	"encoding/csv"
+	"fmt"
+	"io"
+
+	"github.com/jackc/pgx/v5/pgconn"
+)
+
+// csvSink streams query results as CSV. The header row is written on Begin;
+// each data row is written and flushed individually so large exports do not
+// buffer in memory.
+// +// For command-only statements CSV has nothing meaningful to emit (no header, +// no rows): we write the command tag to stderr so machine consumers reading +// stdout still receive an empty document, while humans get a confirmation. +type csvSink struct { + out io.Writer + stderr io.Writer + w *csv.Writer + + // rowsProducing is true once Begin saw a non-empty fields slice. End + // uses it to decide whether to write the command-tag stderr line. + rowsProducing bool +} + +func newCSVSink(out, stderr io.Writer) *csvSink { + return &csvSink{out: out, stderr: stderr, w: csv.NewWriter(out)} +} + +func (s *csvSink) Begin(fields []pgconn.FieldDescription) error { + if len(fields) == 0 { + return nil + } + s.rowsProducing = true + + header := make([]string, len(fields)) + for i, f := range fields { + header[i] = f.Name + } + if err := s.w.Write(header); err != nil { + return fmt.Errorf("write CSV header: %w", err) + } + s.w.Flush() + return s.w.Error() +} + +func (s *csvSink) Row(values []any) error { + row := make([]string, len(values)) + for i, v := range values { + // CSV represents NULL as an empty field, matching `psql --csv`. + if v == nil { + row[i] = "" + continue + } + row[i] = textValue(v) + } + if err := s.w.Write(row); err != nil { + return fmt.Errorf("write CSV row: %w", err) + } + s.w.Flush() + return s.w.Error() +} + +func (s *csvSink) End(commandTag string) error { + if !s.rowsProducing { + _, err := fmt.Fprintln(s.stderr, commandTag) + return err + } + s.w.Flush() + return s.w.Error() +} + +// OnError flushes whatever is buffered in the csv.Writer so the partial result +// is visible to the consumer. csv.Writer has no concept of "open structure", +// so there is nothing more to do. +func (s *csvSink) OnError(err error) { + s.w.Flush() +} diff --git a/experimental/postgres/cmd/render_csv_test.go b/experimental/postgres/cmd/render_csv_test.go new file mode 100644 index 00000000000..35d1c3596f1 --- /dev/null +++ b/experimental/postgres/cmd/render_csv_test.go @@ -0,0 +1,49 @@ +package postgrescmd + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCSVSink_TwoRows(t *testing.T) { + var stdout, stderr bytes.Buffer + s := newCSVSink(&stdout, &stderr) + require.NoError(t, s.Begin(fields("id", "name"))) + require.NoError(t, s.Row([]any{int64(1), "alice"})) + require.NoError(t, s.Row([]any{int64(2), "bob"})) + require.NoError(t, s.End("SELECT 2")) + + assert.Equal(t, "id,name\n1,alice\n2,bob\n", stdout.String()) + assert.Empty(t, stderr.String()) +} + +func TestCSVSink_NULLEmptyField(t *testing.T) { + var stdout, stderr bytes.Buffer + s := newCSVSink(&stdout, &stderr) + require.NoError(t, s.Begin(fields("id", "note"))) + require.NoError(t, s.Row([]any{int64(1), nil})) + require.NoError(t, s.End("SELECT 1")) + + assert.Equal(t, "id,note\n1,\n", stdout.String()) +} + +func TestCSVSink_CommandOnly(t *testing.T) { + var stdout, stderr bytes.Buffer + s := newCSVSink(&stdout, &stderr) + require.NoError(t, s.Begin(nil)) + require.NoError(t, s.End("CREATE DATABASE")) + assert.Empty(t, stdout.String()) + assert.Equal(t, "CREATE DATABASE\n", stderr.String()) +} + +func TestCSVSink_QuotesFieldsWithCommas(t *testing.T) { + var stdout, stderr bytes.Buffer + s := newCSVSink(&stdout, &stderr) + require.NoError(t, s.Begin(fields("note"))) + require.NoError(t, s.Row([]any{"a,b"})) + require.NoError(t, s.End("SELECT 1")) + assert.Contains(t, stdout.String(), `"a,b"`) +} diff --git a/experimental/postgres/cmd/render_json.go 
b/experimental/postgres/cmd/render_json.go
new file mode 100644
index 00000000000..1d9a53a8e8d
--- /dev/null
+++ b/experimental/postgres/cmd/render_json.go
@@ -0,0 +1,173 @@
+package postgrescmd
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"io"
+	"strconv"
+
+	"github.com/jackc/pgx/v5/pgconn"
+)
+
+// jsonSink streams query results as JSON.
+//
+// For rows-producing statements the output is a top-level array of row
+// objects. We use the separator-before-element pattern to avoid the
+// "rewrite the trailing comma" trick and keep the JSON parseable even when
+// iteration ends with a partial result (caller closes the array on OnError).
+//
+// For command-only statements the output is a single object describing the
+// command tag.
+type jsonSink struct {
+	out    io.Writer
+	stderr io.Writer
+
+	// columns are the disambiguated column names: duplicates beyond the first
+	// occurrence are renamed to "<name>__2", "<name>__3", etc. Postgres
+	// allows duplicate output names (`SELECT 1, 1`, joins with two unaliased
+	// `id` columns) but JSON consumers usually want unique keys; we rename
+	// deterministically and warn once on stderr.
+	columns []string
+	oids    []uint32
+
+	// hasOpenedArray is true once the leading `[\n` has been written. Used
+	// by OnError to decide whether to emit the closing `]\n` to keep stdout
+	// well-formed.
+	hasOpenedArray bool
+	// rowsWritten counts emitted rows so the separator decision is trivial:
+	// 0 means "first row, no separator", anything else means "separator first".
+	rowsWritten int
+}
+
+func newJSONSink(out, stderr io.Writer) *jsonSink {
+	return &jsonSink{out: out, stderr: stderr}
+}
+
+func (s *jsonSink) Begin(fields []pgconn.FieldDescription) error {
+	if len(fields) == 0 {
+		// Command-only; we wait until End to emit the {"command": ...} object.
+		return nil
+	}
+
+	s.columns = make([]string, len(fields))
+	s.oids = make([]uint32, len(fields))
+	seen := make(map[string]int, len(fields))
+	dupes := false
+	for i, f := range fields {
+		s.oids[i] = f.DataTypeOID
+		name := f.Name
+		seen[name]++
+		if seen[name] > 1 {
+			dupes = true
+			name = fmt.Sprintf("%s__%d", f.Name, seen[name])
+		}
+		s.columns[i] = name
+	}
+	if dupes {
+		fmt.Fprintln(s.stderr, "Warning: query returned duplicate column names; renamed duplicates to <name>__N. Use AS aliases for stable names.")
+	}
+
+	if _, err := io.WriteString(s.out, "[\n"); err != nil {
+		return err
+	}
+	s.hasOpenedArray = true
+	return nil
+}
+
+func (s *jsonSink) Row(values []any) error {
+	if s.rowsWritten > 0 {
+		if _, err := io.WriteString(s.out, ",\n"); err != nil {
+			return err
+		}
+	}
+
+	// Build the row object as a *map* of column to converted value, then let
+	// json.Marshal handle the encoding. We don't preserve key insertion order
+	// (json package sorts map keys), which is fine for machine consumers; the
+	// columns slice is the canonical order.
+	//
+	// Using ordered emission would require a manual writer. Worth the cost
+	// only if a downstream consumer needs schema-positional output, which
+	// none do today.
+	obj := make(map[string]any, len(s.columns))
+	for i, name := range s.columns {
+		obj[name] = jsonValueWithOID(values[i], s.oids[i])
+	}
+
+	var buf bytes.Buffer
+	enc := json.NewEncoder(&buf)
+	enc.SetEscapeHTML(false)
+	if err := enc.Encode(obj); err != nil {
+		return fmt.Errorf("encode row: %w", err)
+	}
+	// json.Encoder always writes a trailing newline; trim it so our outer
+	// formatting controls the layout.
+	out := bytes.TrimRight(buf.Bytes(), "\n")
+	if _, err := s.out.Write(out); err != nil {
+		return err
+	}
+	s.rowsWritten++
+	return nil
+}
+
+func (s *jsonSink) End(commandTag string) error {
+	if s.hasOpenedArray {
+		_, err := io.WriteString(s.out, "\n]\n")
+		return err
+	}
+	// Command-only path: emit a single object.
+	obj := map[string]any{"command": commandTagVerb(commandTag)}
+	if rows, ok := commandTagRowCount(commandTag); ok {
+		obj["rows_affected"] = rows
+	}
+
+	var buf bytes.Buffer
+	enc := json.NewEncoder(&buf)
+	enc.SetEscapeHTML(false)
+	if err := enc.Encode(obj); err != nil {
+		return fmt.Errorf("encode command tag: %w", err)
+	}
+	_, err := s.out.Write(buf.Bytes())
+	return err
+}
+
+// OnError closes the array cleanly so stdout remains parseable JSON. The
+// caller still propagates the original error, which the command writes to
+// stderr.
+func (s *jsonSink) OnError(err error) {
+	if !s.hasOpenedArray {
+		return
+	}
+	// Best-effort; if this Write fails the stream is already corrupted
+	// and there is nothing more we can do.
+	_, _ = io.WriteString(s.out, "\n]\n")
+}
+
+// commandTagVerb extracts the leading verb from a Postgres command tag (e.g.
+// "INSERT 0 5" -> "INSERT"). Returns the input unchanged if there is no space.
+func commandTagVerb(tag string) string {
+	for i, r := range tag {
+		if r == ' ' {
+			return tag[:i]
+		}
+	}
+	return tag
+}
+
+// commandTagRowCount extracts the trailing row count from a Postgres command
+// tag. INSERT tags have the shape "INSERT <oid> <rows>"; UPDATE/DELETE/SELECT
+// have "VERB <rows>". Returns ok=false for tags without a trailing integer
+// (e.g. "CREATE DATABASE", "SET").
+func commandTagRowCount(tag string) (int64, bool) {
+	for i := len(tag) - 1; i >= 0; i-- {
+		if tag[i] == ' ' {
+			n, err := strconv.ParseInt(tag[i+1:], 10, 64)
+			if err != nil {
+				return 0, false
+			}
+			return n, true
+		}
+	}
+	return 0, false
+}
diff --git a/experimental/postgres/cmd/render_json_test.go b/experimental/postgres/cmd/render_json_test.go
new file mode 100644
index 00000000000..a2617b27bc6
--- /dev/null
+++ b/experimental/postgres/cmd/render_json_test.go
@@ -0,0 +1,118 @@
+package postgrescmd
+
+import (
+	"bytes"
+	"testing"
+
+	"github.com/jackc/pgx/v5/pgconn"
+	"github.com/jackc/pgx/v5/pgtype"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func fieldsWithOIDs(names []string, oids []uint32) []pgconn.FieldDescription {
+	out := make([]pgconn.FieldDescription, len(names))
+	for i, n := range names {
+		out[i] = pgconn.FieldDescription{Name: n, DataTypeOID: oids[i]}
+	}
+	return out
+}
+
+func TestJSONSink_TwoRows(t *testing.T) {
+	var stdout, stderr bytes.Buffer
+	s := newJSONSink(&stdout, &stderr)
+
+	require.NoError(t, s.Begin(fieldsWithOIDs([]string{"id", "name"}, []uint32{pgtype.Int8OID, pgtype.TextOID})))
+	require.NoError(t, s.Row([]any{int64(1), "alice"}))
+	require.NoError(t, s.Row([]any{int64(2), "bob"}))
+	require.NoError(t, s.End("SELECT 2"))
+
+	assert.Equal(t,
+		"[\n"+
+			`{"id":1,"name":"alice"}`+",\n"+
+			`{"id":2,"name":"bob"}`+
+			"\n]\n",
+		stdout.String(),
+	)
+	assert.Empty(t, stderr.String())
+}
+
+func TestJSONSink_EmptyRowsProducing(t *testing.T) {
+	var stdout, stderr bytes.Buffer
+	s := newJSONSink(&stdout, &stderr)
+	require.NoError(t, s.Begin(fieldsWithOIDs([]string{"id"}, []uint32{pgtype.Int8OID})))
+	require.NoError(t, s.End("SELECT 0"))
+	assert.Equal(t, "[\n\n]\n", stdout.String())
+}
+
+func TestJSONSink_CommandOnly_WithRowCount(t *testing.T) {
+	var stdout, stderr bytes.Buffer
+	s := 
newJSONSink(&stdout, &stderr) + require.NoError(t, s.Begin(nil)) + require.NoError(t, s.End("INSERT 0 5")) + assert.JSONEq(t, `{"command":"INSERT","rows_affected":5}`, stdout.String()) +} + +func TestJSONSink_CommandOnly_NoRowCount(t *testing.T) { + var stdout, stderr bytes.Buffer + s := newJSONSink(&stdout, &stderr) + require.NoError(t, s.Begin(nil)) + require.NoError(t, s.End("CREATE DATABASE")) + assert.JSONEq(t, `{"command":"CREATE"}`, stdout.String()) +} + +func TestJSONSink_DuplicateColumns(t *testing.T) { + var stdout, stderr bytes.Buffer + s := newJSONSink(&stdout, &stderr) + require.NoError(t, s.Begin(fieldsWithOIDs([]string{"id", "id", "id"}, []uint32{pgtype.Int8OID, pgtype.Int8OID, pgtype.Int8OID}))) + require.NoError(t, s.Row([]any{int64(1), int64(2), int64(3)})) + require.NoError(t, s.End("SELECT 1")) + + assert.Contains(t, stdout.String(), `"id":1`) + assert.Contains(t, stdout.String(), `"id__2":2`) + assert.Contains(t, stdout.String(), `"id__3":3`) + assert.Contains(t, stderr.String(), "duplicate column names") +} + +func TestJSONSink_OnError_AfterRows(t *testing.T) { + var stdout, stderr bytes.Buffer + s := newJSONSink(&stdout, &stderr) + require.NoError(t, s.Begin(fieldsWithOIDs([]string{"id"}, []uint32{pgtype.Int8OID}))) + require.NoError(t, s.Row([]any{int64(1)})) + s.OnError(assert.AnError) + + assert.Contains(t, stdout.String(), "[\n") + assert.Contains(t, stdout.String(), `{"id":1}`) + assert.Contains(t, stdout.String(), "\n]\n") +} + +func TestJSONSink_OnError_BeforeBegin(t *testing.T) { + var stdout, stderr bytes.Buffer + s := newJSONSink(&stdout, &stderr) + s.OnError(assert.AnError) + assert.Empty(t, stdout.String()) +} + +func TestCommandTagParse(t *testing.T) { + tests := []struct { + tag string + verb string + rows int64 + hasCount bool + }{ + {"INSERT 0 5", "INSERT", 5, true}, + {"UPDATE 3", "UPDATE", 3, true}, + {"DELETE 0", "DELETE", 0, true}, + {"SELECT 100", "SELECT", 100, true}, + {"CREATE DATABASE", "CREATE", 0, false}, + {"SET", "SET", 0, false}, + } + for _, tc := range tests { + assert.Equal(t, tc.verb, commandTagVerb(tc.tag), "verb for %q", tc.tag) + count, ok := commandTagRowCount(tc.tag) + assert.Equal(t, tc.hasCount, ok, "hasCount for %q", tc.tag) + if tc.hasCount { + assert.Equal(t, tc.rows, count, "rows for %q", tc.tag) + } + } +} diff --git a/experimental/postgres/cmd/render_test.go b/experimental/postgres/cmd/render_test.go index 29aeb3c36fc..06190323e43 100644 --- a/experimental/postgres/cmd/render_test.go +++ b/experimental/postgres/cmd/render_test.go @@ -4,21 +4,29 @@ import ( "bytes" "testing" + "github.com/jackc/pgx/v5/pgconn" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestRenderText_RowsProducing(t *testing.T) { - r := &queryResult{ - Columns: []string{"id", "name"}, - Rows: [][]string{ - {"1", "alice"}, - {"2", "bob"}, - }, - CommandTag: "SELECT 2", +// fields is a small helper to build []pgconn.FieldDescription with just names +// (no OIDs), so renderer tests don't need to know about Postgres OIDs. 
+func fields(names ...string) []pgconn.FieldDescription { + out := make([]pgconn.FieldDescription, len(names)) + for i, n := range names { + out[i] = pgconn.FieldDescription{Name: n} } + return out +} + +func TestTextSink_RowsProducing(t *testing.T) { var buf bytes.Buffer - require.NoError(t, renderText(&buf, r)) + s := newTextSink(&buf) + + require.NoError(t, s.Begin(fields("id", "name"))) + require.NoError(t, s.Row([]any{int64(1), "alice"})) + require.NoError(t, s.Row([]any{int64(2), "bob"})) + require.NoError(t, s.End("SELECT 2")) assert.Equal(t, "id name\n"+ @@ -30,38 +38,36 @@ func TestRenderText_RowsProducing(t *testing.T) { ) } -func TestRenderText_SingleRow(t *testing.T) { - r := &queryResult{ - Columns: []string{"id"}, - Rows: [][]string{{"42"}}, - CommandTag: "SELECT 1", - } +func TestTextSink_SingleRow(t *testing.T) { var buf bytes.Buffer - require.NoError(t, renderText(&buf, r)) + s := newTextSink(&buf) + require.NoError(t, s.Begin(fields("id"))) + require.NoError(t, s.Row([]any{int64(42)})) + require.NoError(t, s.End("SELECT 1")) assert.Contains(t, buf.String(), "(1 row)\n") } -func TestRenderText_Empty(t *testing.T) { - r := &queryResult{ - Columns: []string{"id", "name"}, - CommandTag: "SELECT 0", - } +func TestTextSink_Empty(t *testing.T) { var buf bytes.Buffer - require.NoError(t, renderText(&buf, r)) + s := newTextSink(&buf) + require.NoError(t, s.Begin(fields("id", "name"))) + require.NoError(t, s.End("SELECT 0")) assert.Contains(t, buf.String(), "(0 rows)\n") } -func TestRenderText_CommandOnly(t *testing.T) { - r := &queryResult{ - CommandTag: "INSERT 0 5", - } +func TestTextSink_CommandOnly(t *testing.T) { var buf bytes.Buffer - require.NoError(t, renderText(&buf, r)) + s := newTextSink(&buf) + require.NoError(t, s.Begin(nil)) + require.NoError(t, s.End("INSERT 0 5")) assert.Equal(t, "INSERT 0 5\n", buf.String()) } -func TestQueryResultIsRowsProducing(t *testing.T) { - assert.False(t, (&queryResult{}).IsRowsProducing()) - assert.False(t, (&queryResult{CommandTag: "INSERT 0 1"}).IsRowsProducing()) - assert.True(t, (&queryResult{Columns: []string{"a"}}).IsRowsProducing()) +func TestTextSink_NULLRendersAsNULL(t *testing.T) { + var buf bytes.Buffer + s := newTextSink(&buf) + require.NoError(t, s.Begin(fields("id"))) + require.NoError(t, s.Row([]any{nil})) + require.NoError(t, s.End("SELECT 1")) + assert.Contains(t, buf.String(), "NULL") } diff --git a/experimental/postgres/cmd/targeting.go b/experimental/postgres/cmd/targeting.go index 5e72840f952..78e230adaac 100644 --- a/experimental/postgres/cmd/targeting.go +++ b/experimental/postgres/cmd/targeting.go @@ -8,6 +8,7 @@ import ( "github.com/databricks/cli/libs/cmdctx" "github.com/databricks/cli/libs/lakebase/target" "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/service/database" "github.com/databricks/databricks-sdk-go/service/postgres" ) @@ -71,9 +72,10 @@ func validateTargeting(f targetingFlags) error { } // resolveTarget translates the validated flags into a resolvedTarget. -// PR 1 supports autoscaling targeting only; provisioned support is added in -// the next PR. A provisioned-shaped --target returns a clear error pointing at -// the experimental status. +// +// --target accepts either an autoscaling resource path (starts with "projects/") +// or a provisioned instance name (everything else). Granular flags +// (--project, --branch, --endpoint) target autoscaling only. 
 func resolveTarget(ctx context.Context, f targetingFlags) (*resolvedTarget, error) {
 	w := cmdctx.WorkspaceClient(ctx)
@@ -86,9 +88,7 @@ func resolveTarget(ctx context.Context, f targetingFlags) (*resolvedTarget, erro
 		return resolveAutoscaling(ctx, w, spec)
 
 	case f.target != "":
-		// Provisioned-shaped target. Out of scope for this PR; will be wired in
-		// the follow-up PR alongside JSON/CSV output.
-		return nil, errors.New("provisioned instances are not yet supported by this experimental command; use 'databricks psql ' for now")
+		return resolveProvisioned(ctx, w, f.target)
 
 	default:
 		spec := target.AutoscalingSpec{
@@ -100,6 +100,41 @@ func resolveTarget(ctx context.Context, f targetingFlags) (*resolvedTarget, erro
 	}
 }
 
+// resolveProvisioned looks up a provisioned instance and issues a token. The
+// instance must be in the AVAILABLE state; transitional states return an
+// error pointing the user at the lifecycle they are waiting on.
+func resolveProvisioned(ctx context.Context, w *databricks.WorkspaceClient, instanceName string) (*resolvedTarget, error) {
+	instance, err := target.GetProvisioned(ctx, w, instanceName)
+	if err != nil {
+		return nil, err
+	}
+
+	if instance.State != database.DatabaseInstanceStateAvailable {
+		return nil, fmt.Errorf("database instance %q is not ready for accepting connections (state: %s)", instance.Name, instance.State)
+	}
+	if instance.ReadWriteDns == "" {
+		return nil, fmt.Errorf("database instance %q has no read/write DNS yet", instance.Name)
+	}
+
+	user, err := w.CurrentUser.Me(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get current user: %w", err)
+	}
+
+	token, err := target.ProvisionedCredential(ctx, w, instance.Name)
+	if err != nil {
+		return nil, err
+	}
+
+	return &resolvedTarget{
+		Kind:        kindProvisioned,
+		Host:        instance.ReadWriteDns,
+		Username:    user.UserName,
+		Token:       token,
+		DisplayName: instance.Name,
+	}, nil
+}
+
 // resolveAutoscaling expands a partial spec into a fully-resolved endpoint and
 // issues a short-lived OAuth token. Missing branch/endpoint IDs are
 // auto-selected when exactly one candidate exists; ambiguity propagates as an
diff --git a/experimental/postgres/cmd/value.go b/experimental/postgres/cmd/value.go
new file mode 100644
index 00000000000..3049b44a82a
--- /dev/null
+++ b/experimental/postgres/cmd/value.go
@@ -0,0 +1,152 @@
+package postgrescmd
+
+import (
+	"encoding/base64"
+	"encoding/hex"
+	"encoding/json"
+	"fmt"
+	"math"
+	"math/big"
+	"strconv"
+	"time"
+
+	"github.com/jackc/pgx/v5/pgtype"
+)
+
+// safeIntegerBound (2^53 - 1, JavaScript's Number.MAX_SAFE_INTEGER) is the
+// largest integer n such that every integer in [-n, n] has an exact IEEE 754
+// double-precision representation. Beyond this bound, encoding an int64 as a
+// JSON number silently loses precision in JavaScript-style consumers. We
+// render those as JSON strings to preserve the original digits.
+const safeIntegerBound = 1<<53 - 1
+
+// textValue renders a Go value (as decoded by pgx) to its canonical Postgres
+// text representation. Used by --output text and --output csv.
+//
+// NULL renders as the literal "NULL" so it lines up with the column rather
+// than appearing as an empty cell. CSV converts that back to an empty field
+// at write time (matches `psql --csv`).
+func textValue(v any) string { + if v == nil { + return "NULL" + } + + switch x := v.(type) { + case string: + return x + case []byte: + return `\x` + hex.EncodeToString(x) + case bool: + if x { + return "t" + } + return "f" + case time.Time: + return x.Format(time.RFC3339Nano) + case fmt.Stringer: + return x.String() + } + + return fmt.Sprintf("%v", v) +} + +// jsonValue renders a Go value (as decoded by pgx) to a JSON-encodable +// representation. Returns a value the standard json.Marshal can handle +// directly and the JSON shape we want; never returns Go values that would +// silently lose information (e.g. NaN, oversized integers). +// +// The mapping intentionally favours machine-friendly output: +// - jsonb / json bytes round-trip as raw JSON (preserves bigint precision +// inside JSON values, e.g. {"id": 9007199254740993}). +// - bytea encodes as base64. +// - timestamps render in RFC3339 with subsecond precision. +// - Postgres NaN / +Inf / -Inf become JSON strings (JSON has no IEEE-special). +// - Integers outside ±2^53 become JSON strings to preserve precision. +// - Numerics, intervals, geometric types, and unknown types fall back to +// the canonical Postgres text representation as a JSON string. +func jsonValue(v any) any { + if v == nil { + return nil + } + + switch x := v.(type) { + case bool: + return x + case string: + return x + case int8, int16, int32, int, uint8, uint16, uint32: + return x + case int64: + if x > safeIntegerBound || x < -safeIntegerBound { + return strconv.FormatInt(x, 10) + } + return x + case uint64: + if x > safeIntegerBound { + return strconv.FormatUint(x, 10) + } + return x + case float32: + return jsonFloat(float64(x)) + case float64: + return jsonFloat(x) + case []byte: + // Postgres jsonb / json arrive as []byte holding raw JSON. Anything + // else we'd like to base64-encode. We can't tell them apart from the + // Go type alone; the sink calls jsonValueWithOID for oid-aware + // disambiguation. This bare path is the conservative fallback and + // treats unknown bytes as base64 (lossless and correct for bytea). + return base64.StdEncoding.EncodeToString(x) + case time.Time: + return x.UTC().Format(time.RFC3339Nano) + case *big.Int: + // numeric without scale; preserve as string to keep precision. + return x.String() + case fmt.Stringer: + return x.String() + } + + return fmt.Sprintf("%v", v) +} + +// jsonFloat handles the IEEE-special cases that JSON cannot represent. +// Finite values pass through unchanged. +func jsonFloat(f float64) any { + switch { + case math.IsNaN(f): + return "NaN" + case math.IsInf(f, 1): + return "Infinity" + case math.IsInf(f, -1): + return "-Infinity" + } + return f +} + +// jsonValueWithOID applies oid-aware overrides on top of jsonValue. The two +// places this matters today are JSON/JSONB and bytea: both arrive from pgx as +// []byte but want different JSON shapes (raw JSON passthrough vs base64). +func jsonValueWithOID(v any, oid uint32) any { + if v == nil { + return nil + } + + switch oid { + case pgtype.JSONOID, pgtype.JSONBOID: + // pgx returns json/jsonb as already-decoded Go values when no codec + // is registered; with the default codec, they decode to map/slice/etc. + // In QueryExecModeExec text-mode, pgx returns the raw JSON bytes as + // string (since the wire is text). We accept both shapes. 
+ switch x := v.(type) { + case []byte: + return json.RawMessage(x) + case string: + return json.RawMessage(x) + } + case pgtype.ByteaOID: + if b, ok := v.([]byte); ok { + return base64.StdEncoding.EncodeToString(b) + } + } + + return jsonValue(v) +} diff --git a/experimental/postgres/cmd/value_test.go b/experimental/postgres/cmd/value_test.go new file mode 100644 index 00000000000..092fc6f7284 --- /dev/null +++ b/experimental/postgres/cmd/value_test.go @@ -0,0 +1,84 @@ +package postgrescmd + +import ( + "encoding/json" + "math" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/stretchr/testify/assert" +) + +func TestJSONValue_PrimitiveTypes(t *testing.T) { + assert.Equal(t, true, jsonValue(true)) + assert.Equal(t, "hello", jsonValue("hello")) + assert.Equal(t, int64(42), jsonValue(int64(42))) + assert.InDelta(t, 3.14, jsonValue(float64(3.14)), 1e-9) +} + +func TestJSONValue_NULL(t *testing.T) { + assert.Nil(t, jsonValue(nil)) +} + +func TestJSONValue_FloatSpecials(t *testing.T) { + assert.Equal(t, "NaN", jsonValue(math.NaN())) + assert.Equal(t, "Infinity", jsonValue(math.Inf(1))) + assert.Equal(t, "-Infinity", jsonValue(math.Inf(-1))) +} + +func TestJSONValue_LargeIntPreservedAsString(t *testing.T) { + big := int64(1<<53 + 1) + assert.Equal(t, "9007199254740993", jsonValue(big)) + + negBig := -int64(1<<53 + 1) + assert.Equal(t, "-9007199254740993", jsonValue(negBig)) +} + +func TestJSONValue_SafeIntPreservedAsNumber(t *testing.T) { + safe := int64(1<<53 - 1) + assert.Equal(t, safe, jsonValue(safe)) +} + +func TestJSONValue_TimestampToRFC3339(t *testing.T) { + tm := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) + v := jsonValue(tm) + assert.Equal(t, "2024-01-15T10:30:00Z", v) +} + +func TestJSONValueWithOID_JSONBPassthrough(t *testing.T) { + raw := []byte(`{"id":9007199254740993,"name":"alice"}`) + v := jsonValueWithOID(raw, pgtype.JSONBOID) + + encoded, err := json.Marshal(v) + assert.NoError(t, err) + assert.JSONEq(t, string(raw), string(encoded)) +} + +func TestJSONValueWithOID_ByteaToBase64(t *testing.T) { + v := jsonValueWithOID([]byte{0xde, 0xad, 0xbe, 0xef}, pgtype.ByteaOID) + assert.Equal(t, "3q2+7w==", v) +} + +func TestJSONValueWithOID_FallsBackToJSONValue(t *testing.T) { + assert.Equal(t, int64(42), jsonValueWithOID(int64(42), pgtype.Int8OID)) + assert.Nil(t, jsonValueWithOID(nil, pgtype.TextOID)) +} + +func TestTextValue_NULL(t *testing.T) { + assert.Equal(t, "NULL", textValue(nil)) +} + +func TestTextValue_Bool(t *testing.T) { + assert.Equal(t, "t", textValue(true)) + assert.Equal(t, "f", textValue(false)) +} + +func TestTextValue_BytesAsHex(t *testing.T) { + assert.Equal(t, `\xdeadbeef`, textValue([]byte{0xde, 0xad, 0xbe, 0xef})) +} + +func TestTextValue_Time(t *testing.T) { + tm := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) + assert.Equal(t, "2024-01-15T10:30:00Z", textValue(tm)) +} From fdd0e1b8cd8778281d18aa022c0ee2a6567b06f1 Mon Sep 17 00:00:00 2001 From: simon Date: Thu, 30 Apr 2026 11:15:34 +0200 Subject: [PATCH 06/25] Address PR 2 review feedback round 1 - Fix non-finite float text: textValue had no float branch and fell through to fmt.Sprintf, which emits +Inf/-Inf instead of Postgres' Infinity/-Infinity. Added the explicit float case + tests. - Emit JSON object keys in column order, not alphabetical. The map approach inadvertently sorted keys; switched to manual ordered emission (write '{', encode key:value pairs in column order, write '}'). Added a regression test with non-alphabetical column names. 
- Honor --output text on a pipe instead of silently rewriting to JSON. Repo rule says "reject incompatible inputs early; never silently ignore a flag the current mode can't honor". Auto-fallback now only fires when the flag was not explicitly set (or not pinned by env). - Trim impossible Go types from jsonValue (pgx never decodes int8 / uint8/16/32 / uint64 from PG columns). - Drop the redundant ReadWriteDns guard in resolveProvisioned; an AVAILABLE Lakebase instance is documented to have DNS, and cmd/psql doesn't carry the same guard. - Build the unsupported-format error from allOutputFormats so the message stays in sync if a fourth format is added. - Update execute.go's QueryExecModeExec doc to acknowledge that we now call rows.Values() (not RawValues), so all sinks see Go-typed input. - Collapse empty rows-producing JSON to "[\n]\n" and matching OnError. - Add stderr warning helper (commandTagRowCount now covered for MERGE/COPY/FETCH/MOVE). - Test gaps: text +Inf, text finite float, JSON column order, OnError for csv/text sinks, CSV with embedded newline + quote. Co-authored-by: Isaac --- .../query/provisioned-targeting/output.txt | 4 - .../query/provisioned-targeting/script | 3 - .../query/provisioned-targeting/test.toml | 9 -- experimental/postgres/cmd/execute.go | 7 +- experimental/postgres/cmd/output.go | 21 +++- experimental/postgres/cmd/output_test.go | 18 +++- experimental/postgres/cmd/query.go | 5 + experimental/postgres/cmd/render_csv_test.go | 20 ++++ experimental/postgres/cmd/render_json.go | 97 +++++++++++++------ experimental/postgres/cmd/render_json_test.go | 15 ++- experimental/postgres/cmd/render_test.go | 12 +++ experimental/postgres/cmd/targeting.go | 3 - experimental/postgres/cmd/value.go | 33 +++++-- experimental/postgres/cmd/value_test.go | 11 +++ 14 files changed, 192 insertions(+), 66 deletions(-) diff --git a/acceptance/cmd/experimental/postgres/query/provisioned-targeting/output.txt b/acceptance/cmd/experimental/postgres/query/provisioned-targeting/output.txt index 0f00f8b3e44..bb7ebe1ee69 100644 --- a/acceptance/cmd/experimental/postgres/query/provisioned-targeting/output.txt +++ b/acceptance/cmd/experimental/postgres/query/provisioned-targeting/output.txt @@ -3,10 +3,6 @@ >>> musterr [CLI] experimental postgres query --target starting-instance SELECT 1 Error: database instance "starting-instance" is not ready for accepting connections (state: STARTING) -=== Provisioned target with no DNS should error: ->>> musterr [CLI] experimental postgres query --target no-dns-instance SELECT 1 -Error: database instance "no-dns-instance" has no read/write DNS yet - === Provisioned target not found should surface SDK 404: >>> musterr [CLI] experimental postgres query --target missing-instance SELECT 1 Error: failed to get database instance: instance not found diff --git a/acceptance/cmd/experimental/postgres/query/provisioned-targeting/script b/acceptance/cmd/experimental/postgres/query/provisioned-targeting/script index d8995c62a6c..5459e01dfcc 100644 --- a/acceptance/cmd/experimental/postgres/query/provisioned-targeting/script +++ b/acceptance/cmd/experimental/postgres/query/provisioned-targeting/script @@ -1,8 +1,5 @@ title "Provisioned target in non-AVAILABLE state should error:" trace musterr $CLI experimental postgres query --target starting-instance "SELECT 1" -title "Provisioned target with no DNS should error:" -trace musterr $CLI experimental postgres query --target no-dns-instance "SELECT 1" - title "Provisioned target not found should surface SDK 404:" trace 
musterr $CLI experimental postgres query --target missing-instance "SELECT 1" diff --git a/acceptance/cmd/experimental/postgres/query/provisioned-targeting/test.toml b/acceptance/cmd/experimental/postgres/query/provisioned-targeting/test.toml index 4821dab5741..25513a7a975 100644 --- a/acceptance/cmd/experimental/postgres/query/provisioned-targeting/test.toml +++ b/acceptance/cmd/experimental/postgres/query/provisioned-targeting/test.toml @@ -10,15 +10,6 @@ Response.Body = ''' } ''' -[[Server]] -Pattern = "GET /api/2.0/database/instances/no-dns-instance" -Response.Body = ''' -{ - "name": "no-dns-instance", - "state": "AVAILABLE" -} -''' - [[Server]] Pattern = "GET /api/2.0/database/instances/missing-instance" Response.StatusCode = 404 diff --git a/experimental/postgres/cmd/execute.go b/experimental/postgres/cmd/execute.go index 61d93bd7bc2..51f70f836a9 100644 --- a/experimental/postgres/cmd/execute.go +++ b/experimental/postgres/cmd/execute.go @@ -36,8 +36,11 @@ type rowSink interface { // closed at the end of the command, so the cached prepared statement // never gets reused. // 2. Exec mode uses Postgres' extended-protocol "exec" path with text-format -// result columns, which keeps the canonical-Postgres-text rendering for -// --output text and --output csv straightforward. +// result columns. We still call rows.Values() (not RawValues) so all +// three sinks see uniform Go-typed input; jsonValue/textValue then map +// those types back to canonical strings for text/CSV and to JSON-typed +// values for JSON. The wire format being text means pgx's decode is +// cheap (text -> Go) rather than binary -> Go. // // QueryExecModeExec still uses extended protocol with a single statement and // no implicit transaction wrap, so transaction-disallowed DDL like diff --git a/experimental/postgres/cmd/output.go b/experimental/postgres/cmd/output.go index c293b424b73..9976cd0d548 100644 --- a/experimental/postgres/cmd/output.go +++ b/experimental/postgres/cmd/output.go @@ -34,34 +34,45 @@ var allOutputFormats = []outputFormat{outputText, outputJSON, outputCSV} // values are silently ignored, matching cmd/root/io.go and aitools). // 3. The flag default ("text"). // -// Then the auto-selection rule applies: text on a non-TTY stdout falls back -// to JSON. This matches the aitools query command and means scripts piping -// stdout get machine-readable output by default. +// Then the auto-selection rule applies: a *defaulted* text mode on a non-TTY +// stdout falls back to JSON, so scripts piping the output get machine- +// readable output by default. An *explicit* --output text is honoured even +// on a pipe; per CLAUDE.md we don't silently override flags the user set. // // flagSet is true if the user explicitly passed --output. stdoutTTY is true // if stdout is a terminal. 
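+//
+// Worked example of the precedence (values hypothetical): an explicit
+// `--output csv` beats DATABRICKS_OUTPUT_FORMAT=json, which beats the
+// "text" default; and a defaulted `query ... | jq` run (no flag, no env)
+// lands on json via the non-TTY fallback.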
func resolveOutputFormat(ctx context.Context, flagValue string, flagSet, stdoutTTY bool) (outputFormat, error) { chosen := outputFormat(strings.ToLower(flagValue)) + chosenExplicit := flagSet if !flagSet { if v, ok := env.Lookup(ctx, envOutputFormat); ok { candidate := outputFormat(strings.ToLower(v)) if isKnownOutputFormat(candidate) { chosen = candidate + chosenExplicit = true } } } if !isKnownOutputFormat(chosen) { - return "", fmt.Errorf("unsupported output format %q; expected one of: text, json, csv", flagValue) + return "", fmt.Errorf("unsupported output format %q; expected one of: %s", flagValue, joinOutputFormats(allOutputFormats)) } - if chosen == outputText && !stdoutTTY { + if chosen == outputText && !stdoutTTY && !chosenExplicit { return outputJSON, nil } return chosen, nil } +func joinOutputFormats(formats []outputFormat) string { + parts := make([]string, len(formats)) + for i, f := range formats { + parts[i] = string(f) + } + return strings.Join(parts, ", ") +} + func isKnownOutputFormat(f outputFormat) bool { switch f { case outputText, outputJSON, outputCSV: diff --git a/experimental/postgres/cmd/output_test.go b/experimental/postgres/cmd/output_test.go index 79289a43e56..4598085805a 100644 --- a/experimental/postgres/cmd/output_test.go +++ b/experimental/postgres/cmd/output_test.go @@ -23,11 +23,25 @@ func TestResolveOutputFormat_TextOnPipeFallsBackToJSON(t *testing.T) { assert.Equal(t, outputJSON, got) } -func TestResolveOutputFormat_ExplicitTextOnPipeAlsoFallsBackToJSON(t *testing.T) { +func TestResolveOutputFormat_ExplicitTextOnPipeIsHonoured(t *testing.T) { ctx := t.Context() got, err := resolveOutputFormat(ctx, "text", true, false) require.NoError(t, err) - assert.Equal(t, outputJSON, got) + assert.Equal(t, outputText, got) +} + +func TestResolveOutputFormat_EnvVarTextOnPipeIsHonoured(t *testing.T) { + ctx := env.Set(t.Context(), envOutputFormat, "text") + got, err := resolveOutputFormat(ctx, "text", false, false) + require.NoError(t, err) + assert.Equal(t, outputText, got) +} + +func TestResolveOutputFormat_EnvVarCSVOnPipe(t *testing.T) { + ctx := env.Set(t.Context(), envOutputFormat, "csv") + got, err := resolveOutputFormat(ctx, "text", false, false) + require.NoError(t, err) + assert.Equal(t, outputCSV, got) } func TestResolveOutputFormat_ExplicitJSON(t *testing.T) { diff --git a/experimental/postgres/cmd/query.go b/experimental/postgres/cmd/query.go index c3078f24d82..2b4f12694f9 100644 --- a/experimental/postgres/cmd/query.go +++ b/experimental/postgres/cmd/query.go @@ -120,6 +120,11 @@ func runQuery(ctx context.Context, cmd *cobra.Command, sql string, f queryFlags) return err } + // SupportsColor is the public TTY-ish signal libs/cmdio exposes today; it + // also folds in NO_COLOR / TERM=dumb, which strictly speaking are colour + // preferences rather than TTY signals. Users who hit that edge case can + // pass --output text explicitly; that path is honoured (see + // resolveOutputFormat). Mirrors the aitools query command. 
 	stdoutTTY := cmdio.SupportsColor(ctx, cmd.OutOrStdout())
 	format, err := resolveOutputFormat(ctx, f.outputFormat, f.outputFormatSet, stdoutTTY)
 	if err != nil {
diff --git a/experimental/postgres/cmd/render_csv_test.go b/experimental/postgres/cmd/render_csv_test.go
index 35d1c3596f1..5a3ee277e2c 100644
--- a/experimental/postgres/cmd/render_csv_test.go
+++ b/experimental/postgres/cmd/render_csv_test.go
@@ -47,3 +47,23 @@ func TestCSVSink_QuotesFieldsWithCommas(t *testing.T) {
 	require.NoError(t, s.End("SELECT 1"))
 	assert.Contains(t, stdout.String(), `"a,b"`)
 }
+
+func TestCSVSink_EmbeddedNewlineAndQuote(t *testing.T) {
+	var stdout, stderr bytes.Buffer
+	s := newCSVSink(&stdout, &stderr)
+	require.NoError(t, s.Begin(fields("note")))
+	require.NoError(t, s.Row([]any{"line1\nline2 \"quoted\""}))
+	require.NoError(t, s.End("SELECT 1"))
+	assert.Contains(t, stdout.String(), "\"line1\nline2 \"\"quoted\"\"\"")
+}
+
+func TestCSVSink_OnError_NoOp(t *testing.T) {
+	var stdout, stderr bytes.Buffer
+	s := newCSVSink(&stdout, &stderr)
+	require.NoError(t, s.Begin(fields("id")))
+	require.NoError(t, s.Row([]any{int64(1)}))
+	s.OnError(assert.AnError)
+	// CSV has no open structure to close; partial row count plus header is
+	// what the consumer sees. The sink must not panic on OnError.
+	assert.Contains(t, stdout.String(), "id\n1\n")
+}
diff --git a/experimental/postgres/cmd/render_json.go b/experimental/postgres/cmd/render_json.go
index 1d9a53a8e8d..dc713b6b786 100644
--- a/experimental/postgres/cmd/render_json.go
+++ b/experimental/postgres/cmd/render_json.go
@@ -82,53 +82,84 @@ func (s *jsonSink) Row(values []any) error {
 		}
 	}
 
-	// Build the row object as a *map* of column to converted value, then let
-	// json.Marshal handle the encoding. We don't preserve key insertion order
-	// (json package sorts map keys), which is fine for machine consumers; the
-	// columns slice is the canonical order.
-	//
-	// Using ordered emission would require a manual writer. Worth the cost
-	// only if a downstream consumer needs schema-positional output, which
-	// none do today.
-	obj := make(map[string]any, len(s.columns))
+	// Emit keys in column order. json.Marshal on a map sorts keys
+	// alphabetically; SELECT order is what consumers expect, so we write
+	// `{`, walk columns, encode key:value pairs ourselves, then `}`.
+	if _, err := io.WriteString(s.out, "{"); err != nil {
+		return err
+	}
 	for i, name := range s.columns {
-		obj[name] = jsonValueWithOID(values[i], s.oids[i])
+		if i > 0 {
+			if _, err := io.WriteString(s.out, ","); err != nil {
+				return err
+			}
+		}
+		key, err := marshalJSON(name)
+		if err != nil {
+			return fmt.Errorf("encode column name %q: %w", name, err)
+		}
+		if _, err := s.out.Write(key); err != nil {
+			return err
+		}
+		if _, err := io.WriteString(s.out, ":"); err != nil {
+			return err
+		}
+		val, err := marshalJSON(jsonValueWithOID(values[i], s.oids[i]))
+		if err != nil {
+			return fmt.Errorf("encode value for %q: %w", name, err)
+		}
+		if _, err := s.out.Write(val); err != nil {
+			return err
+		}
 	}
+	if _, err := io.WriteString(s.out, "}"); err != nil {
+		return err
+	}
+	s.rowsWritten++
+	return nil
+}
 
+// marshalJSON encodes v with HTML escaping disabled (so jsonb values
+// containing `<`, `>`, or `&` round-trip without `\u003c`-style rewrites).
+// encoding/json's Encoder is the only path that exposes SetEscapeHTML, so we
+// route through it and strip the trailing newline it always appends.
+func marshalJSON(v any) ([]byte, error) {
 	var buf bytes.Buffer
 	enc := json.NewEncoder(&buf)
 	enc.SetEscapeHTML(false)
-	if err := enc.Encode(obj); err != nil {
-		return fmt.Errorf("encode row: %w", err)
-	}
-	// json.Encoder always writes a trailing newline; trim it so our outer
-	// formatting controls the layout.
-	out := bytes.TrimRight(buf.Bytes(), "\n")
-	if _, err := s.out.Write(out); err != nil {
-		return err
+	if err := enc.Encode(v); err != nil {
+		return nil, err
 	}
-	s.rowsWritten++
-	return nil
+	return bytes.TrimRight(buf.Bytes(), "\n"), nil
 }
 
 func (s *jsonSink) End(commandTag string) error {
 	if s.hasOpenedArray {
+		if s.rowsWritten == 0 {
+			// Empty result: collapse to "[\n]\n" rather than "[\n\n]\n".
+			_, err := io.WriteString(s.out, "]\n")
+			return err
+		}
 		_, err := io.WriteString(s.out, "\n]\n")
 		return err
 	}
 
-	// Command-only path: emit a single object.
-	obj := map[string]any{"command": commandTagVerb(commandTag)}
-	if rows, ok := commandTagRowCount(commandTag); ok {
-		obj["rows_affected"] = rows
+	// Command-only path: emit a single ordered object.
+	if _, err := io.WriteString(s.out, `{"command":`); err != nil {
+		return err
 	}
-
-	var buf bytes.Buffer
-	enc := json.NewEncoder(&buf)
-	enc.SetEscapeHTML(false)
-	if err := enc.Encode(obj); err != nil {
-		return fmt.Errorf("encode command tag: %w", err)
+	verbBytes, err := marshalJSON(commandTagVerb(commandTag))
+	if err != nil {
+		return fmt.Errorf("encode command tag verb: %w", err)
+	}
+	if _, err := s.out.Write(verbBytes); err != nil {
+		return err
 	}
-	_, err := s.out.Write(buf.Bytes())
+	if rows, ok := commandTagRowCount(commandTag); ok {
+		if _, err := fmt.Fprintf(s.out, `,"rows_affected":%d`, rows); err != nil {
+			return err
+		}
+	}
+	_, err = io.WriteString(s.out, "}\n")
 	return err
 }
 
@@ -141,6 +172,10 @@ func (s *jsonSink) OnError(err error) {
 	}
 	// Best-effort; if this Write fails the stream is already corrupted
 	// and there is nothing more we can do.
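+	// As in End: nothing has followed the opening "[\n" yet, so close the
+	// array immediately; an error before the first row still yields "[\n]\n".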
+ if s.rowsWritten == 0 { + _, _ = io.WriteString(s.out, "]\n") + return + } _, _ = io.WriteString(s.out, "\n]\n") } diff --git a/experimental/postgres/cmd/render_json_test.go b/experimental/postgres/cmd/render_json_test.go index a2617b27bc6..26aa79cc832 100644 --- a/experimental/postgres/cmd/render_json_test.go +++ b/experimental/postgres/cmd/render_json_test.go @@ -42,7 +42,16 @@ func TestJSONSink_EmptyRowsProducing(t *testing.T) { s := newJSONSink(&stdout, &stderr) require.NoError(t, s.Begin(fieldsWithOIDs([]string{"id"}, []uint32{pgtype.Int8OID}))) require.NoError(t, s.End("SELECT 0")) - assert.Equal(t, "[\n\n]\n", stdout.String()) + assert.Equal(t, "[\n]\n", stdout.String()) +} + +func TestJSONSink_KeysInColumnOrder(t *testing.T) { + var stdout, stderr bytes.Buffer + s := newJSONSink(&stdout, &stderr) + require.NoError(t, s.Begin(fieldsWithOIDs([]string{"b", "a"}, []uint32{pgtype.Int8OID, pgtype.Int8OID}))) + require.NoError(t, s.Row([]any{int64(1), int64(2)})) + require.NoError(t, s.End("SELECT 1")) + assert.Equal(t, "[\n"+`{"b":1,"a":2}`+"\n]\n", stdout.String()) } func TestJSONSink_CommandOnly_WithRowCount(t *testing.T) { @@ -104,6 +113,10 @@ func TestCommandTagParse(t *testing.T) { {"UPDATE 3", "UPDATE", 3, true}, {"DELETE 0", "DELETE", 0, true}, {"SELECT 100", "SELECT", 100, true}, + {"MERGE 5", "MERGE", 5, true}, + {"COPY 1000", "COPY", 1000, true}, + {"FETCH 7", "FETCH", 7, true}, + {"MOVE 3", "MOVE", 3, true}, {"CREATE DATABASE", "CREATE", 0, false}, {"SET", "SET", 0, false}, } diff --git a/experimental/postgres/cmd/render_test.go b/experimental/postgres/cmd/render_test.go index 06190323e43..d451febb191 100644 --- a/experimental/postgres/cmd/render_test.go +++ b/experimental/postgres/cmd/render_test.go @@ -71,3 +71,15 @@ func TestTextSink_NULLRendersAsNULL(t *testing.T) { require.NoError(t, s.End("SELECT 1")) assert.Contains(t, buf.String(), "NULL") } + +func TestTextSink_OnError_NoOp(t *testing.T) { + var buf bytes.Buffer + s := newTextSink(&buf) + require.NoError(t, s.Begin(fields("id"))) + require.NoError(t, s.Row([]any{int64(1)})) + s.OnError(assert.AnError) + // Text sink has no open structure to close. OnError must not panic and + // must not emit a partial table; the partial result lives in s.rows but + // is never flushed. + assert.Empty(t, buf.String()) +} diff --git a/experimental/postgres/cmd/targeting.go b/experimental/postgres/cmd/targeting.go index 78e230adaac..4c46bee02dc 100644 --- a/experimental/postgres/cmd/targeting.go +++ b/experimental/postgres/cmd/targeting.go @@ -112,9 +112,6 @@ func resolveProvisioned(ctx context.Context, w *databricks.WorkspaceClient, inst if instance.State != database.DatabaseInstanceStateAvailable { return nil, fmt.Errorf("database instance %q is not ready for accepting connections (state: %s)", instance.Name, instance.State) } - if instance.ReadWriteDns == "" { - return nil, fmt.Errorf("database instance %q has no read/write DNS yet", instance.Name) - } user, err := w.CurrentUser.Me(ctx) if err != nil { diff --git a/experimental/postgres/cmd/value.go b/experimental/postgres/cmd/value.go index 3049b44a82a..21beedd04f0 100644 --- a/experimental/postgres/cmd/value.go +++ b/experimental/postgres/cmd/value.go @@ -25,6 +25,11 @@ const safeIntegerBound = 1<<53 - 1 // NULL renders as the literal "NULL" so it lines up with the column rather // than appearing as an empty cell. CSV converts that back to an empty field // at write time (matches `psql --csv`). 
+// +// Floats are rendered with Postgres' canonical wording for the IEEE specials +// ("NaN" / "Infinity" / "-Infinity"), not Go's `fmt.Sprintf("%v")` defaults +// (which would emit "+Inf"/"-Inf"). This keeps text/CSV consistent with what +// `psql` would print. func textValue(v any) string { if v == nil { return "NULL" @@ -40,6 +45,10 @@ func textValue(v any) string { return "t" } return "f" + case float64: + return floatTextForm(x) + case float32: + return floatTextForm(float64(x)) case time.Time: return x.Format(time.RFC3339Nano) case fmt.Stringer: @@ -49,6 +58,20 @@ func textValue(v any) string { return fmt.Sprintf("%v", v) } +// floatTextForm formats a float using Postgres' canonical text wording for +// the IEEE specials and Go's shortest-round-trip 'g' format otherwise. +func floatTextForm(f float64) string { + switch { + case math.IsNaN(f): + return "NaN" + case math.IsInf(f, 1): + return "Infinity" + case math.IsInf(f, -1): + return "-Infinity" + } + return strconv.FormatFloat(f, 'g', -1, 64) +} + // jsonValue renders a Go value (as decoded by pgx) to a JSON-encodable // representation. Returns a value the standard json.Marshal can handle // directly and the JSON shape we want; never returns Go values that would @@ -73,18 +96,16 @@ func jsonValue(v any) any { return x case string: return x - case int8, int16, int32, int, uint8, uint16, uint32: + case int16, int32: return x case int64: + // pgx decodes Postgres int8 to Go int64. Outside the IEEE-754 safe + // integer range we render as a string so JavaScript-style consumers + // don't silently lose precision. if x > safeIntegerBound || x < -safeIntegerBound { return strconv.FormatInt(x, 10) } return x - case uint64: - if x > safeIntegerBound { - return strconv.FormatUint(x, 10) - } - return x case float32: return jsonFloat(float64(x)) case float64: diff --git a/experimental/postgres/cmd/value_test.go b/experimental/postgres/cmd/value_test.go index 092fc6f7284..d52edae90bc 100644 --- a/experimental/postgres/cmd/value_test.go +++ b/experimental/postgres/cmd/value_test.go @@ -82,3 +82,14 @@ func TestTextValue_Time(t *testing.T) { tm := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) assert.Equal(t, "2024-01-15T10:30:00Z", textValue(tm)) } + +func TestTextValue_FloatSpecials(t *testing.T) { + assert.Equal(t, "NaN", textValue(math.NaN())) + assert.Equal(t, "Infinity", textValue(math.Inf(1))) + assert.Equal(t, "-Infinity", textValue(math.Inf(-1))) +} + +func TestTextValue_FiniteFloat(t *testing.T) { + assert.Equal(t, "3.14", textValue(float64(3.14))) + assert.Equal(t, "0", textValue(float64(0))) +} From 287dd62aab56a80597288b85c30f9068e09b1a11 Mon Sep 17 00:00:00 2001 From: simon Date: Thu, 30 Apr 2026 11:21:18 +0200 Subject: [PATCH 07/25] Address PR 2 review feedback round 2 - Doc fix: textSink.OnError doc said "prints whatever rows have been collected" but text mode buffers everything to End. New doc states the buffered partial result is discarded on iteration error. - Doc fix: textValue float comment overstated psql parity. Tightened to acknowledge Go's 'g' format may differ from psql in exponential vs fixed notation around the boundary. - Tighten OnError contract: explicitly states it is NOT called when Begin itself errors. - Replace switch-by-format in isKnownOutputFormat with slices.Contains on allOutputFormats so adding a fourth format is one edit. - Tighten command-only JSON tests from JSONEq (key-order ignored) to byte-equal so a future field addition is caught. 
- Tighten JSONSink_OnError tests to byte-equal; add the Begin-but-no-rows case which exercises the rowsWritten==0 branch. Co-authored-by: Isaac --- experimental/postgres/cmd/execute.go | 8 +++++--- experimental/postgres/cmd/output.go | 7 ++----- experimental/postgres/cmd/render.go | 7 ++++--- experimental/postgres/cmd/render_json_test.go | 17 ++++++++++++----- experimental/postgres/cmd/value.go | 10 ++++++---- 5 files changed, 29 insertions(+), 20 deletions(-) diff --git a/experimental/postgres/cmd/execute.go b/experimental/postgres/cmd/execute.go index 51f70f836a9..8d0b896031c 100644 --- a/experimental/postgres/cmd/execute.go +++ b/experimental/postgres/cmd/execute.go @@ -21,9 +21,11 @@ type rowSink interface { Row(values []any) error // End is called once on successful completion. End(commandTag string) error - // OnError is called if iteration errors after Begin returned. The sink - // is expected to flush any in-progress output structures so stdout - // remains well-formed. The caller still surfaces err to its caller. + // OnError is called if iteration errors after Begin returned successfully. + // The sink is expected to flush any in-progress output structures so + // stdout remains well-formed. The caller still surfaces err to its caller. + // If Begin itself errors, OnError is NOT called: sinks must not write any + // framing before Begin succeeds. OnError(err error) } diff --git a/experimental/postgres/cmd/output.go b/experimental/postgres/cmd/output.go index 9976cd0d548..e5b59fec96f 100644 --- a/experimental/postgres/cmd/output.go +++ b/experimental/postgres/cmd/output.go @@ -3,6 +3,7 @@ package postgrescmd import ( "context" "fmt" + "slices" "strings" "github.com/databricks/cli/libs/env" @@ -74,9 +75,5 @@ func joinOutputFormats(formats []outputFormat) string { } func isKnownOutputFormat(f outputFormat) bool { - switch f { - case outputText, outputJSON, outputCSV: - return true - } - return false + return slices.Contains(allOutputFormats, f) } diff --git a/experimental/postgres/cmd/render.go b/experimental/postgres/cmd/render.go index bc45c89e0d0..2e1daf6376b 100644 --- a/experimental/postgres/cmd/render.go +++ b/experimental/postgres/cmd/render.go @@ -62,9 +62,10 @@ func (s *textSink) End(commandTag string) error { return err } -// OnError for text sinks is a no-op: text output prints whatever rows have -// already been collected, with no open structure to close. The caller -// surfaces the error separately (cobra's default error rendering). +// OnError for text sinks is a no-op. Text mode buffers all rows for +// tabwriter alignment, so a partial result is discarded on iteration error; +// only cobra's error message reaches the user. The streaming sinks (json, +// csv) handle the partial-result case themselves. func (s *textSink) OnError(err error) {} func headerSeparator(cols []string) []string { diff --git a/experimental/postgres/cmd/render_json_test.go b/experimental/postgres/cmd/render_json_test.go index 26aa79cc832..4e6f474d257 100644 --- a/experimental/postgres/cmd/render_json_test.go +++ b/experimental/postgres/cmd/render_json_test.go @@ -59,7 +59,9 @@ func TestJSONSink_CommandOnly_WithRowCount(t *testing.T) { s := newJSONSink(&stdout, &stderr) require.NoError(t, s.Begin(nil)) require.NoError(t, s.End("INSERT 0 5")) - assert.JSONEq(t, `{"command":"INSERT","rows_affected":5}`, stdout.String()) + // Byte-equal: pins the field order so adding a future field (e.g. last_oid) + // must update the test rather than silently drift. 
+ assert.Equal(t, `{"command":"INSERT","rows_affected":5}`+"\n", stdout.String()) } func TestJSONSink_CommandOnly_NoRowCount(t *testing.T) { @@ -67,7 +69,7 @@ func TestJSONSink_CommandOnly_NoRowCount(t *testing.T) { s := newJSONSink(&stdout, &stderr) require.NoError(t, s.Begin(nil)) require.NoError(t, s.End("CREATE DATABASE")) - assert.JSONEq(t, `{"command":"CREATE"}`, stdout.String()) + assert.Equal(t, `{"command":"CREATE"}`+"\n", stdout.String()) } func TestJSONSink_DuplicateColumns(t *testing.T) { @@ -89,10 +91,15 @@ func TestJSONSink_OnError_AfterRows(t *testing.T) { require.NoError(t, s.Begin(fieldsWithOIDs([]string{"id"}, []uint32{pgtype.Int8OID}))) require.NoError(t, s.Row([]any{int64(1)})) s.OnError(assert.AnError) + assert.Equal(t, "[\n"+`{"id":1}`+"\n]\n", stdout.String()) +} - assert.Contains(t, stdout.String(), "[\n") - assert.Contains(t, stdout.String(), `{"id":1}`) - assert.Contains(t, stdout.String(), "\n]\n") +func TestJSONSink_OnError_AfterBeginNoRows(t *testing.T) { + var stdout, stderr bytes.Buffer + s := newJSONSink(&stdout, &stderr) + require.NoError(t, s.Begin(fieldsWithOIDs([]string{"id"}, []uint32{pgtype.Int8OID}))) + s.OnError(assert.AnError) + assert.Equal(t, "[\n]\n", stdout.String()) } func TestJSONSink_OnError_BeforeBegin(t *testing.T) { diff --git a/experimental/postgres/cmd/value.go b/experimental/postgres/cmd/value.go index 21beedd04f0..1578c7efecf 100644 --- a/experimental/postgres/cmd/value.go +++ b/experimental/postgres/cmd/value.go @@ -26,10 +26,12 @@ const safeIntegerBound = 1<<53 - 1 // than appearing as an empty cell. CSV converts that back to an empty field // at write time (matches `psql --csv`). // -// Floats are rendered with Postgres' canonical wording for the IEEE specials -// ("NaN" / "Infinity" / "-Infinity"), not Go's `fmt.Sprintf("%v")` defaults -// (which would emit "+Inf"/"-Inf"). This keeps text/CSV consistent with what -// `psql` would print. +// IEEE special floats use Postgres' canonical wording ("NaN" / "Infinity" +// / "-Infinity"), not Go's `fmt.Sprintf("%v")` defaults (which would emit +// "+Inf"/"-Inf"). Finite floats use Go's shortest-round-trip 'g' format, +// which may differ from psql in exponential vs fixed notation around the +// 'g' boundary (e.g. Go prints `1e+10`; psql prints `10000000000`). Full +// psql parity is not worth a custom formatter. func textValue(v any) string { if v == nil { return "NULL" From 495849d64a2b68eb1426743c4a624790fca96129 Mon Sep 17 00:00:00 2001 From: simon Date: Thu, 30 Apr 2026 11:30:21 +0200 Subject: [PATCH 08/25] Multi-input + multi-statement rejection + pg error formatting This is PR 3 of the experimental postgres query stack. Adds the rest of the input ergonomics promised in the plan and the error-formatting polish. Inputs: positional args become variadic, --file is repeatable, stdin is read when neither is present, and a positional ending in '.sql' that exists on disk is treated as a SQL file. Execution order is files-first then positionals (cobra/pflag does not preserve interleaved spelling, documented in --help). Each input unit must contain exactly one statement. checkSingleStatement walks the SQL with a hand-written conservative scanner that ignores ';' inside single-quoted strings, double-quoted identifiers, line comments, block comments, and dollar-quoted bodies. Multi-statement strings are rejected before connect with a hint pointing at the multi-input alternatives. 
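
A flavour of the scanner's contract, with cases lifted from the unit table
added in this commit (illustrative, not exhaustive):

  checkSingleStatement(`SELECT 'a;b'`)        // nil: ';' inside a string
  checkSingleStatement(`SELECT $$a;b$$`)      // nil: ';' inside a dollar-quoted body
  checkSingleStatement(`SELECT 1 -- x;y`)     // nil: ';' inside a line comment
  checkSingleStatement(`SELECT 1;`)           // nil: one trailing ';' is allowed
  checkSingleStatement(`SELECT 1; SELECT 2`)  // errMultipleStatements
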
Multi-input output: - text: each per-unit result rendered inline, separated by a blank line (mirrors psql's compact text shape). - json: top-level array of per-unit result objects with shape {"sql","kind","elapsed_ms",...}; rows-producing units carry a "rows":[...] array, command-only carry "command"+"rows_affected". Each per-unit object is buffered to completion before write; the outer array streams across units. The plan accepts this trade-off: huge SELECTs in multi-input invocations buffer. - csv: rejected pre-flight when N>1 (no sensible cross-schema shape). Single-input csv keeps streaming. Per-unit errors render as a {"kind":"error", ...} entry in the JSON shape so scripts can detect failure without checking exit code. Sequential execution stops on the first failing unit; the successful prefix is rendered. formatPgError renders *pgconn.PgError with SEVERITY, SQLSTATE, DETAIL, HINT inline. Non-PgError values pass through unchanged so connect-time errors keep their original wording. Single-input keeps the streaming sinks from PR 2; only multi-input goes through the buffered renderer. Session state (SET, temp tables) carries across input units because they share one connection. TUI for >30 rows is deferred to a follow-up. The current text path uses the static tabwriter table for both single- and multi-input. Co-authored-by: Isaac --- .../postgres/query/argument-errors/output.txt | 23 +- .../postgres/query/argument-errors/script | 14 ++ experimental/postgres/cmd/error.go | 42 ++++ experimental/postgres/cmd/error_test.go | 48 ++++ experimental/postgres/cmd/inputs.go | 102 +++++++++ experimental/postgres/cmd/inputs_test.go | 101 +++++++++ experimental/postgres/cmd/multistatement.go | 159 +++++++++++++ .../postgres/cmd/multistatement_test.go | 54 +++++ experimental/postgres/cmd/query.go | 147 +++++++++--- experimental/postgres/cmd/render_multi.go | 209 ++++++++++++++++++ .../postgres/cmd/render_multi_test.go | 89 ++++++++ experimental/postgres/cmd/result.go | 62 ++++++ 12 files changed, 1017 insertions(+), 33 deletions(-) create mode 100644 experimental/postgres/cmd/error.go create mode 100644 experimental/postgres/cmd/error_test.go create mode 100644 experimental/postgres/cmd/inputs.go create mode 100644 experimental/postgres/cmd/inputs_test.go create mode 100644 experimental/postgres/cmd/multistatement.go create mode 100644 experimental/postgres/cmd/multistatement_test.go create mode 100644 experimental/postgres/cmd/render_multi.go create mode 100644 experimental/postgres/cmd/render_multi_test.go create mode 100644 experimental/postgres/cmd/result.go diff --git a/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt b/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt index 238e099299c..3b6fe7910a4 100644 --- a/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt +++ b/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt @@ -1,11 +1,11 @@ === No SQL argument should error: >>> musterr [CLI] experimental postgres query --target projects/foo -Error: accepts 1 arg(s), received 0 +Error: no SQL provided === Empty SQL should error: >>> musterr [CLI] experimental postgres query --target projects/foo -Error: no SQL provided +Error: argv[1] is empty === Neither targeting form should error: >>> musterr [CLI] experimental postgres query SELECT 1 @@ -42,3 +42,22 @@ Error: invalid resource path: missing project ID === Trailing components after endpoint should error: >>> musterr [CLI] experimental postgres query --target 
projects/foo/branches/bar/endpoints/baz/extra SELECT 1 Error: invalid resource path: trailing components after endpoint: projects/foo/branches/bar/endpoints/baz/extra + +=== Multi-statement string should error with hint: +>>> musterr [CLI] experimental postgres query --target projects/foo SELECT 1; SELECT 2 +Error: argv[1]: input contains multiple statements (a ';' separates two or more statements) +This command runs one statement per input. To run multiple statements: + - Pass each as a separate positional: query "SELECT 1" "SELECT 2" + - Pass each in its own --file: query --file q1.sql --file q2.sql + +=== CSV with multiple inputs should reject pre-flight: +>>> musterr [CLI] experimental postgres query --target projects/foo --output csv SELECT 1 SELECT 2 +Error: --output csv requires a single input unit; got 2 (use --output json for multi-input invocations) + +=== Empty file should error: +>>> musterr [CLI] experimental postgres query --target projects/foo --file empty.sql +Error: --file "empty.sql" is empty + +=== Missing file should error: +>>> musterr [CLI] experimental postgres query --target projects/foo --file /tmp/does-not-exist.sql +Error: read --file "/tmp/does-not-exist.sql": open /tmp/does-not-exist.sql: no such file or directory diff --git a/acceptance/cmd/experimental/postgres/query/argument-errors/script b/acceptance/cmd/experimental/postgres/query/argument-errors/script index ac6ac42746e..a1401d3b8e4 100644 --- a/acceptance/cmd/experimental/postgres/query/argument-errors/script +++ b/acceptance/cmd/experimental/postgres/query/argument-errors/script @@ -30,3 +30,17 @@ trace musterr $CLI experimental postgres query --target projects/ "SELECT 1" title "Trailing components after endpoint should error:" trace musterr $CLI experimental postgres query --target projects/foo/branches/bar/endpoints/baz/extra "SELECT 1" + +title "Multi-statement string should error with hint:" +trace musterr $CLI experimental postgres query --target projects/foo "SELECT 1; SELECT 2" + +title "CSV with multiple inputs should reject pre-flight:" +trace musterr $CLI experimental postgres query --target projects/foo --output csv "SELECT 1" "SELECT 2" + +title "Empty file should error:" +echo "" > empty.sql +trace musterr $CLI experimental postgres query --target projects/foo --file empty.sql +rm -f empty.sql + +title "Missing file should error:" +trace musterr $CLI experimental postgres query --target projects/foo --file /tmp/does-not-exist.sql diff --git a/experimental/postgres/cmd/error.go b/experimental/postgres/cmd/error.go new file mode 100644 index 00000000000..02278a6c58b --- /dev/null +++ b/experimental/postgres/cmd/error.go @@ -0,0 +1,42 @@ +package postgrescmd + +import ( + "errors" + "fmt" + "strings" + + "github.com/jackc/pgx/v5/pgconn" +) + +// formatPgError renders an error in a friendlier form when it's a Postgres +// server-side error. *pgconn.PgError exposes Code, Severity, Message, Detail, +// Hint, and Position; the plain text form attaches what's set so users see +// SQLSTATE plus any hint upstream included. +// +// For non-PgError values, returns err.Error() unchanged so the caller can +// surface it directly. The richer LINE+caret rendering is out of scope for +// this PR; we stick with the plain shape for now. 
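+//
+// Rendered shape for a hypothetical unique violation:
+//
+//	ERROR: duplicate key value violates unique constraint "t_pkey" (SQLSTATE 23505)
+//	DETAIL: Key (id)=(1) already exists.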
+func formatPgError(err error) string { + var pgErr *pgconn.PgError + if !errors.As(err, &pgErr) { + return err.Error() + } + + var sb strings.Builder + if pgErr.Severity != "" { + fmt.Fprintf(&sb, "%s: ", pgErr.Severity) + } else { + sb.WriteString("ERROR: ") + } + sb.WriteString(pgErr.Message) + if pgErr.Code != "" { + fmt.Fprintf(&sb, " (SQLSTATE %s)", pgErr.Code) + } + if pgErr.Detail != "" { + fmt.Fprintf(&sb, "\nDETAIL: %s", pgErr.Detail) + } + if pgErr.Hint != "" { + fmt.Fprintf(&sb, "\nHINT: %s", pgErr.Hint) + } + return sb.String() +} diff --git a/experimental/postgres/cmd/error_test.go b/experimental/postgres/cmd/error_test.go new file mode 100644 index 00000000000..f4d709468d1 --- /dev/null +++ b/experimental/postgres/cmd/error_test.go @@ -0,0 +1,48 @@ +package postgrescmd + +import ( + "errors" + "testing" + + "github.com/jackc/pgx/v5/pgconn" + "github.com/stretchr/testify/assert" +) + +func TestFormatPgError_NonPgError(t *testing.T) { + err := errors.New("plain error") + assert.Equal(t, "plain error", formatPgError(err)) +} + +func TestFormatPgError_BasicPgError(t *testing.T) { + err := &pgconn.PgError{ + Severity: "ERROR", + Code: "42601", + Message: `syntax error at or near "FRO"`, + } + assert.Equal(t, + `ERROR: syntax error at or near "FRO" (SQLSTATE 42601)`, + formatPgError(err), + ) +} + +func TestFormatPgError_WithDetailAndHint(t *testing.T) { + err := &pgconn.PgError{ + Severity: "ERROR", + Code: "42601", + Message: `syntax error at or near "FRO"`, + Hint: `Did you mean "FROM"?`, + Detail: "more context", + } + got := formatPgError(err) + assert.Contains(t, got, "ERROR:") + assert.Contains(t, got, "(SQLSTATE 42601)") + assert.Contains(t, got, "DETAIL: more context") + assert.Contains(t, got, `HINT: Did you mean "FROM"?`) +} + +func TestFormatPgError_WrappedPgError(t *testing.T) { + pg := &pgconn.PgError{Code: "42501", Message: "permission denied"} + wrapped := errors.New("query failed: " + pg.Error()) + // Plain error doesn't unwrap; falls through to err.Error. + assert.Contains(t, formatPgError(wrapped), "permission denied") +} diff --git a/experimental/postgres/cmd/inputs.go b/experimental/postgres/cmd/inputs.go new file mode 100644 index 00000000000..3cc64d45ad4 --- /dev/null +++ b/experimental/postgres/cmd/inputs.go @@ -0,0 +1,102 @@ +package postgrescmd + +import ( + "context" + "errors" + "fmt" + "io" + "os" + "strings" + + "github.com/databricks/cli/libs/cmdio" +) + +// sqlFileExtension is the file suffix that triggers the .sql autodetect on a +// positional argument: if `databricks ... query foo.sql` exists on disk, we +// read it as a SQL file; otherwise it's treated as literal SQL. +const sqlFileExtension = ".sql" + +// inputUnit is one SQL statement to execute, paired with metadata so the +// renderer can identify its origin in multi-input output shapes. +type inputUnit struct { + // SQL is the trimmed statement text. Always non-empty by the time the + // scanner has rejected multi-statement strings and empty inputs. + SQL string + // Source is a human-readable label for this input ("--file query.sql", + // "stdin", or "argv[1]"). Used by the multi-input JSON renderer's "sql" + // field hint and by the rich error formatter. + Source string +} + +// collectInputs assembles the ordered list of input units from positional +// arguments, --file flags, and stdin. +// +// Execution order is files-first then positionals (plan section "Statement +// execution"). 
Cobra/pflag does not preserve the user's interleaved CLI +// spelling: it collects all --file flags into one slice and all positionals +// into another, so we cannot honour `--file q1.sql "SELECT 1" --file q2.sql` +// as written. This is documented in --help. +// +// Stdin is read only when neither positional nor --file is provided. +func collectInputs(ctx context.Context, in io.Reader, args, files []string) ([]inputUnit, error) { + var units []inputUnit + + for _, path := range files { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read --file %q: %w", path, err) + } + sql := strings.TrimSpace(string(data)) + if sql == "" { + return nil, fmt.Errorf("--file %q is empty", path) + } + units = append(units, inputUnit{SQL: sql, Source: "--file " + path}) + } + + for i, arg := range args { + // .sql autodetect: if the positional ends in .sql AND the file + // exists, read it as a SQL file. Other read errors (permission + // denied) surface directly. If the file does not exist, fall through + // and treat the positional as literal SQL — useful when the user + // passes a string that happens to end with ".sql". + if strings.HasSuffix(arg, sqlFileExtension) { + data, err := os.ReadFile(arg) + if err != nil && !errors.Is(err, os.ErrNotExist) { + return nil, fmt.Errorf("read positional %q: %w", arg, err) + } + if err == nil { + sql := strings.TrimSpace(string(data)) + if sql == "" { + return nil, fmt.Errorf("positional %q is empty", arg) + } + units = append(units, inputUnit{SQL: sql, Source: arg}) + continue + } + } + sql := strings.TrimSpace(arg) + if sql == "" { + return nil, fmt.Errorf("argv[%d] is empty", i+1) + } + units = append(units, inputUnit{SQL: sql, Source: fmt.Sprintf("argv[%d]", i+1)}) + } + + if len(units) == 0 { + // No positionals, no --file: read from stdin if it's not a prompt- + // supporting TTY. The aitools query helper applies the same rule. 
+ _, isOsFile := in.(*os.File) + if isOsFile && cmdio.IsPromptSupported(ctx) { + return nil, errors.New("no SQL provided; pass a SQL string, use --file, or pipe via stdin") + } + data, err := io.ReadAll(in) + if err != nil { + return nil, fmt.Errorf("read stdin: %w", err) + } + sql := strings.TrimSpace(string(data)) + if sql == "" { + return nil, errors.New("no SQL provided") + } + units = append(units, inputUnit{SQL: sql, Source: "stdin"}) + } + + return units, nil +} diff --git a/experimental/postgres/cmd/inputs_test.go b/experimental/postgres/cmd/inputs_test.go new file mode 100644 index 00000000000..97d3d2abc70 --- /dev/null +++ b/experimental/postgres/cmd/inputs_test.go @@ -0,0 +1,101 @@ +package postgrescmd + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func writeTemp(t *testing.T, name, contents string) string { + t.Helper() + dir := t.TempDir() + p := filepath.Join(dir, name) + require.NoError(t, os.WriteFile(p, []byte(contents), 0o644)) + return p +} + +func TestCollectInputs_PositionalOnly(t *testing.T) { + units, err := collectInputs(t.Context(), strings.NewReader(""), []string{"SELECT 1"}, nil) + require.NoError(t, err) + require.Len(t, units, 1) + assert.Equal(t, "SELECT 1", units[0].SQL) + assert.Equal(t, "argv[1]", units[0].Source) +} + +func TestCollectInputs_MultiplePositionals(t *testing.T) { + units, err := collectInputs(t.Context(), strings.NewReader(""), []string{"SELECT 1", "SELECT 2"}, nil) + require.NoError(t, err) + require.Len(t, units, 2) + assert.Equal(t, "SELECT 1", units[0].SQL) + assert.Equal(t, "SELECT 2", units[1].SQL) +} + +func TestCollectInputs_FileOnly(t *testing.T) { + p := writeTemp(t, "q.sql", "SELECT * FROM t") + units, err := collectInputs(t.Context(), strings.NewReader(""), nil, []string{p}) + require.NoError(t, err) + require.Len(t, units, 1) + assert.Equal(t, "SELECT * FROM t", units[0].SQL) + assert.Contains(t, units[0].Source, "--file") +} + +func TestCollectInputs_FilesFirstThenPositionals(t *testing.T) { + p1 := writeTemp(t, "a.sql", "SELECT 1") + p2 := writeTemp(t, "b.sql", "SELECT 2") + units, err := collectInputs(t.Context(), strings.NewReader(""), []string{"SELECT 3"}, []string{p1, p2}) + require.NoError(t, err) + require.Len(t, units, 3) + assert.Equal(t, "SELECT 1", units[0].SQL) + assert.Equal(t, "SELECT 2", units[1].SQL) + assert.Equal(t, "SELECT 3", units[2].SQL) +} + +func TestCollectInputs_DotSQLAutoDetect(t *testing.T) { + p := writeTemp(t, "data.sql", "SELECT 42") + units, err := collectInputs(t.Context(), strings.NewReader(""), []string{p}, nil) + require.NoError(t, err) + require.Len(t, units, 1) + assert.Equal(t, "SELECT 42", units[0].SQL) +} + +func TestCollectInputs_DotSQLNotExistingFallsThroughToLiteral(t *testing.T) { + units, err := collectInputs(t.Context(), strings.NewReader(""), []string{"/nonexistent/path.sql"}, nil) + require.NoError(t, err) + require.Len(t, units, 1) + assert.Equal(t, "/nonexistent/path.sql", units[0].SQL) +} + +func TestCollectInputs_StdinOnly(t *testing.T) { + units, err := collectInputs(t.Context(), strings.NewReader("SELECT 1\n"), nil, nil) + require.NoError(t, err) + require.Len(t, units, 1) + assert.Equal(t, "SELECT 1", units[0].SQL) + assert.Equal(t, "stdin", units[0].Source) +} + +func TestCollectInputs_StdinIgnoredWhenPositionalsPresent(t *testing.T) { + units, err := collectInputs(t.Context(), strings.NewReader("FROM STDIN"), []string{"SELECT 1"}, nil) + require.NoError(t, err) + 
require.Len(t, units, 1) + assert.Equal(t, "SELECT 1", units[0].SQL) +} + +func TestCollectInputs_EmptyStdinErrors(t *testing.T) { + _, err := collectInputs(t.Context(), strings.NewReader(""), nil, nil) + assert.ErrorContains(t, err, "no SQL provided") +} + +func TestCollectInputs_EmptyFileErrors(t *testing.T) { + p := writeTemp(t, "empty.sql", "") + _, err := collectInputs(t.Context(), strings.NewReader(""), nil, []string{p}) + assert.ErrorContains(t, err, "is empty") +} + +func TestCollectInputs_EmptyPositional(t *testing.T) { + _, err := collectInputs(t.Context(), strings.NewReader(""), []string{" "}, nil) + assert.ErrorContains(t, err, "is empty") +} diff --git a/experimental/postgres/cmd/multistatement.go b/experimental/postgres/cmd/multistatement.go new file mode 100644 index 00000000000..4c4f976e8e7 --- /dev/null +++ b/experimental/postgres/cmd/multistatement.go @@ -0,0 +1,159 @@ +package postgrescmd + +import ( + "errors" + "strings" +) + +// errMultipleStatements is the typed error returned by checkSingleStatement +// when the input contains more than one ';'-separated statement. The runQuery +// path catches this with errors.Is to attach the multi-input workaround +// pointer in the user-visible message. +var errMultipleStatements = errors.New("input contains multiple statements (a ';' separates two or more statements)") + +// checkSingleStatement walks sql and returns errMultipleStatements if a +// statement-terminating ';' is found anywhere except trailing whitespace. +// +// The scanner ignores ';' inside: +// - single-quoted strings ('a;b', SQL standard doubled-quote escape) +// - double-quoted identifiers ("col;name") +// - line comments (-- ... \n) +// - block comments (/* ... */, non-nesting) +// - dollar-quoted bodies ($tag$ ... $tag$, optional tag) +// +// Over-rejection on weird syntactic edge cases is acceptable: users get a +// clear error and can split into multiple input units. v2 may swap this for +// a real Postgres tokenizer. +func checkSingleStatement(sql string) error { + s := sql + // Trim trailing whitespace once so a single trailing ';' is allowed. + end := len(strings.TrimRight(s, " \t\r\n")) + + i := 0 + for i < end { + c := s[i] + + switch c { + case ';': + // A ';' that's not at end-of-trimmed-input is a separator. + if i < end-1 { + return errMultipleStatements + } + // Trailing ';' is fine. + i++ + + case '\'': + // Single-quoted string. SQL standard escape is '' (doubled). + i = scanQuoted(s, i, end, '\'') + + case '"': + // Double-quoted identifier. Same '"' doubling escape rule. + i = scanQuoted(s, i, end, '"') + + case '-': + // Line comment "--" runs to next newline. + if i+1 < end && s[i+1] == '-' { + i = scanLineComment(s, i, end) + } else { + i++ + } + + case '/': + // Block comment "/* ... */". + if i+1 < end && s[i+1] == '*' { + i = scanBlockComment(s, i, end) + } else { + i++ + } + + case '$': + // Dollar-quoted body: $tag$ ... $tag$ (tag may be empty). + tag, end2 := readDollarTag(s, i, end) + if tag != "" || end2 > i { + i = scanDollarBody(s, end2, end, tag) + } else { + i++ + } + + default: + i++ + } + } + + return nil +} + +// scanQuoted advances past a quoted string or identifier opened at s[start] +// with the given quote character. SQL standard doubles the quote to escape +// (e.g. doubling the quote inside the string). Returns the index of the byte AFTER the closing quote, or +// end if the string is unterminated (over-permissive: an unterminated string +// at EOF means there's no ';' inside it anyway). 
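+//
+// e.g. scanning `'it''s;ok'` from its opening quote consumes the doubled ''
+// escape along with the embedded ';' and returns the index just past the
+// closing quote.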
+func scanQuoted(s string, start, end int, quote byte) int { + i := start + 1 + for i < end { + if s[i] == quote { + if i+1 < end && s[i+1] == quote { + i += 2 // doubled-quote escape + continue + } + return i + 1 + } + i++ + } + return end +} + +func scanLineComment(s string, start, end int) int { + i := start + 2 + for i < end && s[i] != '\n' { + i++ + } + return i +} + +func scanBlockComment(s string, start, end int) int { + i := start + 2 + for i+1 < end { + if s[i] == '*' && s[i+1] == '/' { + return i + 2 + } + i++ + } + return end +} + +// readDollarTag inspects s[start] (which must be '$') and returns the tag +// between the two dollar signs and the index right after the closing first +// '$' of $tag$. If the construct doesn't look like a valid dollar-quote +// opener, returns ("", start) so the caller can fall through. +// +// Tag rule: starts after '$', runs to the next '$', and must consist of +// letter-or-underscore-or-digit (we accept all non-special bytes; over- +// permissive). Empty tag is valid: $$ is a marker, $$body$$ is the body. +func readDollarTag(s string, start, end int) (string, int) { + i := start + 1 + for i < end { + if s[i] == '$' { + tag := s[start+1 : i] + return tag, i + 1 + } + // Stop at characters that can't be in a tag. + if s[i] == ' ' || s[i] == '\t' || s[i] == '\n' || s[i] == ';' { + return "", start + } + i++ + } + return "", start +} + +// scanDollarBody advances past a $tag$...$tag$ body starting at start (the +// byte right after the opening tag's closing '$'). Returns the index of the +// byte AFTER the closing tag, or end if unterminated. +func scanDollarBody(s string, start, end int, tag string) int { + close := "$" + tag + "$" + idx := strings.Index(s[start:end], close) + if idx < 0 { + return end + } + return start + idx + len(close) +} diff --git a/experimental/postgres/cmd/multistatement_test.go b/experimental/postgres/cmd/multistatement_test.go new file mode 100644 index 00000000000..bb50bf5e8ef --- /dev/null +++ b/experimental/postgres/cmd/multistatement_test.go @@ -0,0 +1,54 @@ +package postgrescmd + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCheckSingleStatement(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + }{ + {name: "single statement", input: "SELECT 1", wantErr: false}, + {name: "trailing semicolon allowed", input: "SELECT 1;", wantErr: false}, + {name: "trailing semicolon plus whitespace", input: "SELECT 1;\n ", wantErr: false}, + {name: "two statements rejected", input: "SELECT 1; SELECT 2", wantErr: true}, + {name: "two statements with trailing semi", input: "SELECT 1; SELECT 2;", wantErr: true}, + + {name: "semicolon in single-quoted string", input: "SELECT 'a;b'", wantErr: false}, + {name: "semicolon in double-quoted ident", input: `SELECT "col;name" FROM t`, wantErr: false}, + {name: "doubled quote escape", input: "SELECT 'it''s;ok'", wantErr: false}, + {name: "doubled identifier quote", input: `SELECT "x""y;z" FROM t`, wantErr: false}, + + {name: "semicolon in line comment", input: "SELECT 1 -- x;y\n", wantErr: false}, + {name: "semicolon in block comment", input: "SELECT 1 /* x;y */", wantErr: false}, + {name: "block comment unterminated", input: "SELECT 1 /* unterminated", wantErr: false}, + + {name: "semicolon in dollar body untagged", input: "SELECT $$a;b$$", wantErr: false}, + {name: "semicolon in dollar body tagged", input: "SELECT $tag$a;b$tag$", wantErr: false}, + {name: "create function with body", input: "CREATE FUNCTION f() RETURNS int AS $$ 
BEGIN; END $$ LANGUAGE plpgsql", wantErr: false}, + + {name: "semi inside string then real semi", input: "SELECT 'a;b'; SELECT 2", wantErr: true}, + {name: "semi inside line comment then real semi", input: "SELECT 1 -- ; \n; SELECT 2", wantErr: true}, + {name: "semi inside dollar then real semi", input: "SELECT $$a;b$$; SELECT 2", wantErr: true}, + + {name: "leading whitespace", input: " ;", wantErr: false}, + {name: "empty input", input: "", wantErr: false}, + {name: "only whitespace", input: " \n\t ", wantErr: false}, + {name: "only semicolon", input: ";", wantErr: false}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := checkSingleStatement(tc.input) + if tc.wantErr { + assert.ErrorIs(t, err, errMultipleStatements) + return + } + assert.NoError(t, err) + }) + } +} diff --git a/experimental/postgres/cmd/query.go b/experimental/postgres/cmd/query.go index 2b4f12694f9..bc339c29d43 100644 --- a/experimental/postgres/cmd/query.go +++ b/experimental/postgres/cmd/query.go @@ -2,10 +2,8 @@ package postgrescmd import ( "context" - "errors" "fmt" "io" - "strings" "time" "github.com/databricks/cli/cmd/root" @@ -26,6 +24,7 @@ type queryFlags struct { database string connectTimeout time.Duration maxRetries int + files []string // outputFormat is the raw flag value. resolveOutputFormat turns it into // the effective format (which may differ when stdout is piped). @@ -37,9 +36,9 @@ func newQueryCmd() *cobra.Command { var f queryFlags cmd := &cobra.Command{ - Use: "query [SQL]", - Short: "Run a SQL statement against a Lakebase Postgres endpoint", - Long: `Execute a single SQL statement against a Lakebase Postgres endpoint. + Use: "query [SQL | file.sql]...", + Short: "Run SQL statements against a Lakebase Postgres endpoint", + Long: `Execute one or more SQL statements against a Lakebase Postgres endpoint. Targeting (exactly one form required): --target STRING Provisioned instance name OR autoscaling resource path @@ -48,37 +47,43 @@ Targeting (exactly one form required): --branch ID Autoscaling branch ID (default: auto-select if exactly one) --endpoint ID Autoscaling endpoint ID +Inputs (positionals and --file may be combined; execution order is files-first +then positionals; stdin is used only when neither is present): + -f, --file PATH SQL file path (repeatable). Each file must contain + exactly one statement. + positional SQL string OR path ending in '.sql' that exists on disk. + Output: --output text Aligned table for rows-producing statements (default). Falls back to JSON when stdout is not a terminal so scripts piping the output get machine-readable results. - --output json Top-level array of row objects, streamed for - rows-producing statements. Command-only statements - emit a single {"command": "...", "rows_affected": N} - object. Numbers, booleans, NULL, jsonb, timestamps - render with their JSON-native types. + --output json For a single input: top-level array of row objects, + streamed. For multiple inputs: top-level array of + per-unit result objects ({"sql","kind","elapsed_ms",...}), + with each object buffered to completion. --output csv Header row + one CSV row per result row, streamed. - Command-only statements write the command tag to - stderr. + Single-input only; multi-input + csv is rejected + pre-flight. Use --output json for multi-input. DATABRICKS_OUTPUT_FORMAT is honoured when --output is not explicitly set. -This is an experimental command. The flag set, output shape, and supported -target kinds will expand in subsequent releases. 
- Limitations (this release): - - Single SQL statement per invocation (multi-statement support comes later). + - Single statement per input unit. Multi-statement strings (e.g. + "SELECT 1; SELECT 2") are rejected; pass each as a separate positional + or --file. - No interactive REPL. 'databricks psql' continues to own that surface. - - Multi-statement strings (e.g. "SELECT 1; SELECT 2") are not supported. + - Inputs run sequentially on one connection; session state (SET, temp + tables, prepared statement names) carries across them. - The OAuth token is generated once per invocation and is valid for 1h. Queries longer than that fail with an auth error. + - --output csv is rejected when more than one input unit is present; + use --output json or split into separate invocations. `, - Args: cobra.ExactArgs(1), PreRunE: root.MustWorkspaceClient, RunE: func(cmd *cobra.Command, args []string) error { f.outputFormatSet = cmd.Flag("output").Changed - return runQuery(cmd.Context(), cmd, args[0], f) + return runQuery(cmd.Context(), cmd, args, f) }, } @@ -89,6 +94,7 @@ Limitations (this release): cmd.Flags().StringVarP(&f.database, "database", "d", defaultDatabase, "Database name") cmd.Flags().DurationVar(&f.connectTimeout, "connect-timeout", defaultConnectTimeout, "Connect timeout") cmd.Flags().IntVar(&f.maxRetries, "max-retries", 3, "Total connect attempts on idle/waking endpoint (must be >= 1; 1 disables retry)") + cmd.Flags().StringArrayVarP(&f.files, "file", "f", nil, "SQL file path (repeatable)") cmd.Flags().StringVarP(&f.outputFormat, "output", "o", string(outputText), "Output format: text, json, or csv") cmd.RegisterFlagCompletionFunc("output", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) { out := make([]string, len(allOutputFormats)) @@ -108,11 +114,7 @@ Limitations (this release): // runQuery is the production entry point. It is split out from RunE so unit // tests can call it directly with a stubbed connectFunc once we add seam-based // tests in a later PR. -func runQuery(ctx context.Context, cmd *cobra.Command, sql string, f queryFlags) error { - sql = strings.TrimSpace(sql) - if sql == "" { - return errors.New("no SQL provided") - } +func runQuery(ctx context.Context, cmd *cobra.Command, args []string, f queryFlags) error { if f.maxRetries < 1 { return fmt.Errorf("--max-retries must be at least 1; got %d", f.maxRetries) } @@ -120,17 +122,30 @@ func runQuery(ctx context.Context, cmd *cobra.Command, sql string, f queryFlags) return err } - // SupportsColor is the public TTY-ish signal libs/cmdio exposes today; it - // also folds in NO_COLOR / TERM=dumb, which strictly speaking are colour - // preferences rather than TTY signals. Users who hit that edge case can - // pass --output text explicitly; that path is honoured (see - // resolveOutputFormat). Mirrors the aitools query command. + units, err := collectInputs(ctx, cmd.InOrStdin(), args, f.files) + if err != nil { + return err + } + for _, u := range units { + if err := checkSingleStatement(u.SQL); err != nil { + return fmt.Errorf("%s: %w%s", u.Source, err, multiStatementHint()) + } + } + stdoutTTY := cmdio.SupportsColor(ctx, cmd.OutOrStdout()) format, err := resolveOutputFormat(ctx, f.outputFormat, f.outputFormatSet, stdoutTTY) if err != nil { return err } + // CSV multi-input is rejected pre-flight: there is no sensible shape for + // a CSV that has to merge schemas across statements. 
The error names the
+	// flag pair and tells the user how to recover, per the repo rule about
+	// rejecting incompatible inputs early.
+	if format == outputCSV && len(units) > 1 {
+		return fmt.Errorf("--output csv requires a single input unit; got %d (use --output json for multi-input invocations)", len(units))
+	}
+
 	resolved, err := resolveTarget(ctx, f.targetingFlags)
 	if err != nil {
 		return err
@@ -162,8 +177,43 @@ func runQuery(ctx context.Context, cmd *cobra.Command, sql string, f queryFlags)
 	}
 	defer conn.Close(context.WithoutCancel(ctx))
 
-	sink := newSink(format, cmd.OutOrStdout(), cmd.ErrOrStderr())
-	return executeOne(ctx, conn, sql, sink)
+	out := cmd.OutOrStdout()
+	stderr := cmd.ErrOrStderr()
+
+	if len(units) == 1 {
+		// Single-input path: stream directly through the per-format sink.
+		// Avoids buffering rows for large exports and matches the v1 single-
+		// input behaviour PR 2 shipped.
+		sink := newSink(format, out, stderr)
+		return executeOne(ctx, conn, units[0].SQL, sink)
+	}
+
+	// Multi-input path: per-unit buffering. The plan accepts this trade-off
+	// (multi-input invocations with huge SELECTs should use single-input
+	// invocations with --output csv for streaming). Session state (SET,
+	// temp tables) carries across units because we hold the same connection.
+	results := make([]*unitResult, 0, len(units))
+	for _, u := range units {
+		r, err := runUnitBuffered(ctx, conn, u)
+		if err != nil {
+			// Render the successful prefix, then surface the error with
+			// rich PgError formatting if applicable.
+			if rerr := renderPartial(out, stderr, format, results, u, err); rerr != nil {
+				// Best-effort partial render failed; surface the original
+				// error to the user and note the render failure as a
+				// warning on stderr.
+				fmt.Fprintln(stderr, "warning: failed to render partial result:", rerr)
+			}
+			return formatExecutionError(u.Source, err)
+		}
+		results = append(results, r)
+	}
+
+	switch format {
+	case outputJSON:
+		return renderJSONMulti(out, stderr, results, -1, "")
+	default:
+		return renderTextMulti(out, results)
+	}
 }
 
 // newSink returns the rowSink for the chosen output format. Kept separate
@@ -178,3 +228,38 @@ func newSink(format outputFormat, out, stderr io.Writer) rowSink {
 		return newTextSink(out)
 	}
 }
+
+// renderPartial emits the rendered output for the prefix of units that ran
+// successfully before a unit errored. For multi-input json this also writes
+// the error envelope as the last array element.
+func renderPartial(out, stderr io.Writer, format outputFormat, results []*unitResult, errored inputUnit, err error) error {
+	switch format {
+	case outputJSON:
+		return renderJSONMulti(out, stderr, results, len(results), formatExecutionErrorMessage(errored.Source, err))
+	default:
+		// Text: render whatever ran cleanly. The error message goes through
+		// cobra's default error path on stderr after we return.
+		return renderTextMulti(out, results)
+	}
+}
+
+// formatExecutionError produces the error returned to cobra when an input
+// unit failed. The message includes the source label so the user knows
+// which of N inputs blew up.
+func formatExecutionError(source string, err error) error {
+	return fmt.Errorf("%s: %s", source, formatPgError(err))
+}
+
+// formatExecutionErrorMessage is the string form of formatExecutionError,
+// suitable for embedding in JSON envelopes.
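+// A rendered example (values invented):
+//
+//	--file q2.sql: ERROR: relation "t" does not exist (SQLSTATE 42P01)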
+func formatExecutionErrorMessage(source string, err error) string { + return fmt.Sprintf("%s: %s", source, formatPgError(err)) +} + +// multiStatementHint is the workaround pointer appended to the +// errMultipleStatements error so users see the recovery path inline. +func multiStatementHint() string { + return "\nThis command runs one statement per input. To run multiple statements:\n" + + ` - Pass each as a separate positional: query "SELECT 1" "SELECT 2"` + "\n" + + ` - Pass each in its own --file: query --file q1.sql --file q2.sql` +} diff --git a/experimental/postgres/cmd/render_multi.go b/experimental/postgres/cmd/render_multi.go new file mode 100644 index 00000000000..6ffc1e8d580 --- /dev/null +++ b/experimental/postgres/cmd/render_multi.go @@ -0,0 +1,209 @@ +package postgrescmd + +import ( + "bytes" + "fmt" + "io" + "strings" +) + +// renderTextMulti renders a sequence of unit results as plain text. Each +// per-unit block follows the single-input layout (table for rows-producing, +// command tag for command-only); successive blocks are separated by a blank +// line, mirroring `psql -c "...; ..."` shape. +// +// errIndex/errResult identifies the unit that errored (-1 if none); we still +// render any successful prefix. The error itself is surfaced by the caller +// via cobra's default error rendering. +func renderTextMulti(out io.Writer, results []*unitResult) error { + for i, r := range results { + if i > 0 { + if _, err := io.WriteString(out, "\n"); err != nil { + return err + } + } + if err := renderTextResult(out, r); err != nil { + return err + } + } + return nil +} + +// renderTextResult renders a single buffered unitResult in the same shape as +// textSink would for a streamed result. +func renderTextResult(out io.Writer, r *unitResult) error { + if !r.IsRowsProducing() { + _, err := fmt.Fprintln(out, r.CommandTag) + return err + } + + // Reuse textSink for the table layout so single-input and multi-input + // share the same alignment and footer logic. + sink := newTextSink(out) + if err := sink.Begin(r.Fields); err != nil { + return err + } + for _, row := range r.Rows { + if err := sink.Row(row); err != nil { + return err + } + } + return sink.End(r.CommandTag) +} + +// renderJSONMulti emits the wrapped multi-input JSON shape: a top-level +// array of result objects, one per input unit. Per-unit objects are buffered +// to completion before write; the outer array uses separator-before-element +// streaming. CSV multi-input is rejected pre-flight, so this function is +// only used for json. +// +// Per-unit shape: +// +// {"sql": "...", "kind": "rows", "elapsed_ms": N, "rows": [...]} +// {"sql": "...", "kind": "command", "elapsed_ms": N, "command": "...", "rows_affected": N} +// {"sql": "...", "kind": "error", "elapsed_ms": N, "error": {...}} +// +// kind discriminates which fields are present so consumers don't have to +// branch on key presence. +func renderJSONMulti(out, stderr io.Writer, results []*unitResult, errIndex int, errMessage string) error { + if _, err := io.WriteString(out, "[\n"); err != nil { + return err + } + + for i, r := range results { + if i > 0 { + if _, err := io.WriteString(out, ",\n"); err != nil { + return err + } + } + var unitBuf bytes.Buffer + if err := renderJSONUnit(&unitBuf, stderr, r); err != nil { + return err + } + if _, err := out.Write(unitBuf.Bytes()); err != nil { + return err + } + } + + if errIndex >= 0 { + // The errored unit follows the last successful unit; write a comma + // separator and the error envelope for it. 
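+		// (With an empty success prefix there is no separator to write;
+		// the error envelope becomes the array's first element.)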
+ if len(results) > 0 { + if _, err := io.WriteString(out, ",\n"); err != nil { + return err + } + } + errSQL := "" + errSource := "" + // errIndex points to the input *unit* index; since we render + // successful units in order, the errored unit's SQL came from the + // caller's units slice. The caller embeds it in errMessage so we + // don't need separate plumbing here. + obj := jsonErrorObject(errSource, errSQL, errMessage) + if _, err := out.Write(obj); err != nil { + return err + } + } + + _, err := io.WriteString(out, "\n]\n") + return err +} + +// renderJSONUnit writes one buffered result object to buf, using the +// existing single-input json rendering for the rows array so the value +// mapping stays consistent across single- and multi-input shapes. +func renderJSONUnit(buf *bytes.Buffer, stderr io.Writer, r *unitResult) error { + if !r.IsRowsProducing() { + // Command-only unit. + if _, err := fmt.Fprintf(buf, `{"sql":`); err != nil { + return err + } + sqlJSON, err := marshalJSON(r.SQL) + if err != nil { + return err + } + buf.Write(sqlJSON) + fmt.Fprintf(buf, `,"kind":"command","elapsed_ms":%d`, r.Elapsed.Milliseconds()) + fmt.Fprintf(buf, `,"command":"%s"`, jsonEscapeShort(commandTagVerb(r.CommandTag))) + if rows, ok := commandTagRowCount(r.CommandTag); ok { + fmt.Fprintf(buf, `,"rows_affected":%d`, rows) + } + buf.WriteString(`}`) + return nil + } + + // Rows-producing unit. We reuse jsonSink for the rows array body so + // the per-row encoding (column order, type mapping) stays in one place. + if _, err := fmt.Fprintf(buf, `{"sql":`); err != nil { + return err + } + sqlJSON, err := marshalJSON(r.SQL) + if err != nil { + return err + } + buf.Write(sqlJSON) + fmt.Fprintf(buf, `,"kind":"rows","elapsed_ms":%d,"rows":`, r.Elapsed.Milliseconds()) + + rowsBuf := &bytes.Buffer{} + sink := newJSONSink(rowsBuf, stderr) + if err := sink.Begin(r.Fields); err != nil { + return err + } + for _, row := range r.Rows { + if err := sink.Row(row); err != nil { + return err + } + } + // Use a no-op tag for End so jsonSink's success path emits the closing + // bracket. The trailing newline gets trimmed below. + if err := sink.End(""); err != nil { + return err + } + rowsTrimmed := bytes.TrimRight(rowsBuf.Bytes(), "\n") + buf.Write(rowsTrimmed) + buf.WriteString(`}`) + return nil +} + +// jsonErrorObject builds the per-unit error envelope used in the multi-input +// JSON shape. message is the formatted error message (already includes +// SQLSTATE / hint / detail when applicable). +func jsonErrorObject(source, sql, message string) []byte { + var buf bytes.Buffer + buf.WriteString(`{"source":`) + if b, err := marshalJSON(source); err == nil { + buf.Write(b) + } else { + buf.WriteString(`""`) + } + buf.WriteString(`,"sql":`) + if b, err := marshalJSON(sql); err == nil { + buf.Write(b) + } else { + buf.WriteString(`""`) + } + buf.WriteString(`,"kind":"error","error":{"message":`) + if b, err := marshalJSON(message); err == nil { + buf.Write(b) + } else { + buf.WriteString(`""`) + } + buf.WriteString(`}}`) + return buf.Bytes() +} + +// jsonEscapeShort is a fast path for short ASCII strings (command tag verbs) +// that need backslash escapes for ", \, and control bytes only. Falls back +// to a string-escaped value if the input contains anything unusual. 
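+// For example, "INSERT" or "SET" passes through on the fast path, while a
+// hypothetical verb containing '"' or '\' would take the marshalJSON
+// fallback.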
+func jsonEscapeShort(s string) string { + if !strings.ContainsAny(s, "\"\\\n\r\t") { + return s + } + out, err := marshalJSON(s) + if err != nil { + return s + } + // marshalJSON returns the value with surrounding quotes; strip them so + // the caller can wrap with its own quoting. + return string(bytes.Trim(out, `"`)) +} diff --git a/experimental/postgres/cmd/render_multi_test.go b/experimental/postgres/cmd/render_multi_test.go new file mode 100644 index 00000000000..dba5174a435 --- /dev/null +++ b/experimental/postgres/cmd/render_multi_test.go @@ -0,0 +1,89 @@ +package postgrescmd + +import ( + "bytes" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRenderTextMulti_TwoResults(t *testing.T) { + r1 := &unitResult{ + Source: "argv[1]", + SQL: "INSERT INTO t VALUES (1)", + CommandTag: "INSERT 0 1", + Elapsed: 5 * time.Millisecond, + } + r2 := &unitResult{ + Source: "argv[2]", + SQL: "SELECT id FROM t", + Fields: fieldsWithOIDs([]string{"id"}, []uint32{pgtype.Int8OID}), + Rows: [][]any{{int64(1)}}, + CommandTag: "SELECT 1", + Elapsed: 3 * time.Millisecond, + } + + var buf bytes.Buffer + require.NoError(t, renderTextMulti(&buf, []*unitResult{r1, r2})) + out := buf.String() + assert.Contains(t, out, "INSERT 0 1\n") + assert.Contains(t, out, "id") + assert.Contains(t, out, "(1 row)") + // Blank-line separator between blocks. + assert.Contains(t, out, "INSERT 0 1\n\n") +} + +func TestRenderJSONMulti_TwoResults(t *testing.T) { + r1 := &unitResult{ + Source: "argv[1]", + SQL: "INSERT INTO t VALUES (1)", + CommandTag: "INSERT 0 1", + Elapsed: 5 * time.Millisecond, + } + r2 := &unitResult{ + Source: "argv[2]", + SQL: "SELECT id FROM t", + Fields: fieldsWithOIDs([]string{"id"}, []uint32{pgtype.Int8OID}), + Rows: [][]any{{int64(1)}, {int64(2)}}, + CommandTag: "SELECT 2", + Elapsed: 3 * time.Millisecond, + } + + var stdout, stderr bytes.Buffer + require.NoError(t, renderJSONMulti(&stdout, &stderr, []*unitResult{r1, r2}, -1, "")) + + out := stdout.String() + assert.Contains(t, out, `"sql":"INSERT INTO t VALUES (1)"`) + assert.Contains(t, out, `"kind":"command"`) + assert.Contains(t, out, `"command":"INSERT"`) + assert.Contains(t, out, `"rows_affected":1`) + assert.Contains(t, out, `"sql":"SELECT id FROM t"`) + assert.Contains(t, out, `"kind":"rows"`) + assert.Contains(t, out, `"rows":`) + // Outer array framing. 
+ assert.Greater(t, len(out), 4) + assert.Equal(t, byte('['), out[0]) + assert.Equal(t, byte('\n'), out[len(out)-1]) +} + +func TestRenderJSONMulti_WithErrorAtEnd(t *testing.T) { + r1 := &unitResult{ + Source: "argv[1]", + SQL: "SELECT 1", + Fields: fieldsWithOIDs([]string{"?column?"}, []uint32{pgtype.Int8OID}), + Rows: [][]any{{int64(1)}}, + CommandTag: "SELECT 1", + Elapsed: 1 * time.Millisecond, + } + + var stdout, stderr bytes.Buffer + require.NoError(t, renderJSONMulti(&stdout, &stderr, []*unitResult{r1}, 1, "argv[2]: ERROR: syntax error (SQLSTATE 42601)")) + + out := stdout.String() + assert.Contains(t, out, `"kind":"rows"`) + assert.Contains(t, out, `"kind":"error"`) + assert.Contains(t, out, `"message":"argv[2]: ERROR: syntax error (SQLSTATE 42601)"`) +} diff --git a/experimental/postgres/cmd/result.go b/experimental/postgres/cmd/result.go new file mode 100644 index 00000000000..d9b449a4847 --- /dev/null +++ b/experimental/postgres/cmd/result.go @@ -0,0 +1,62 @@ +package postgrescmd + +import ( + "context" + "fmt" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +// unitResult is the buffered result of running one input unit. The +// multi-input renderers (text, json) need rows buffered before they can +// emit a per-unit block; for the single-input path we still stream +// directly through a rowSink and never produce a unitResult. +type unitResult struct { + Source string + SQL string + Fields []pgconn.FieldDescription + Rows [][]any + CommandTag string + Elapsed time.Duration +} + +// IsRowsProducing returns whether the unit returned a row description. +func (r *unitResult) IsRowsProducing() bool { + return len(r.Fields) > 0 +} + +// runUnitBuffered runs sql and collects every row into memory. Used by the +// multi-input output paths (text and json), where per-unit buffering is +// acceptable per the plan: a multi-input invocation that emits a huge +// SELECT will buffer that result before printing. Users with huge result +// sets per statement should use single-input invocations (which fully +// stream) or --output csv on a single input. +func runUnitBuffered(ctx context.Context, conn *pgx.Conn, unit inputUnit) (*unitResult, error) { + start := time.Now() + rows, err := conn.Query(ctx, unit.SQL, pgx.QueryExecModeExec) + if err != nil { + return nil, fmt.Errorf("query failed: %w", err) + } + defer rows.Close() + + r := &unitResult{ + Source: unit.Source, + SQL: unit.SQL, + Fields: rows.FieldDescriptions(), + } + for rows.Next() { + values, err := rows.Values() + if err != nil { + return nil, fmt.Errorf("decode row: %w", err) + } + r.Rows = append(r.Rows, values) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("query failed: %w", err) + } + r.CommandTag = rows.CommandTag().String() + r.Elapsed = time.Since(start) + return r, nil +} From 78357bdb760b664ba98592d209d8d0aec34efc81 Mon Sep 17 00:00:00 2001 From: simon Date: Thu, 30 Apr 2026 11:41:57 +0200 Subject: [PATCH 09/25] Address PR 3 review feedback round 1 MUSTs: - Multi-input JSON error envelope: thread the failing *unitResult into renderJSONMulti so source/sql/elapsed_ms reflect the actual failing input instead of empty strings. - Canonical key order for every per-unit object: {"source", "sql", "kind", "elapsed_ms", payload} Success and error envelopes now share the same shape so consumers don't have to special-case kind=="error" for missing fields. SHOULDs: - Single-input path now goes through formatPgError, so DETAIL/HINT surface consistently across single- and multi-input. 
- runUnitBuffered reuses executeOne via a new bufferSink. The two query loops collapse to one; future error-handling changes auto- propagate. - Scanner: reject `$...` as a dollar-quote tag (PG docs forbid digit-leading tags). Pinned with a test for `SELECT $1, $2 FROM t` and `SELECT $1 FROM t; SELECT 2`. - Pin the E-string over-rejection behaviour with a test, so a future scanner improvement has to update the assertion. CONSIDERs: - Capture elapsed_ms on the error path too (was previously discarded). - Promote multiStatementHint to a const. - Drop jsonEscapeShort (was a fragile micro-opt for an always-ASCII domain); use marshalJSON for the command verb instead. - Add TestRenderJSONMulti_FirstUnitFails to pin the empty-success- prefix framing. Co-authored-by: Isaac --- experimental/postgres/cmd/multistatement.go | 13 +- .../postgres/cmd/multistatement_test.go | 12 ++ experimental/postgres/cmd/query.go | 39 +++--- experimental/postgres/cmd/render_multi.go | 125 ++++++++---------- .../postgres/cmd/render_multi_test.go | 38 ++++-- experimental/postgres/cmd/result.go | 63 +++++---- 6 files changed, 157 insertions(+), 133 deletions(-) diff --git a/experimental/postgres/cmd/multistatement.go b/experimental/postgres/cmd/multistatement.go index 4c4f976e8e7..4bfedbbfaba 100644 --- a/experimental/postgres/cmd/multistatement.go +++ b/experimental/postgres/cmd/multistatement.go @@ -127,9 +127,12 @@ func scanBlockComment(s string, start, end int) int { // '$' of $tag$. If the construct doesn't look like a valid dollar-quote // opener, returns ("", start) so the caller can fall through. // -// Tag rule: starts after '$', runs to the next '$', and must consist of -// letter-or-underscore-or-digit (we accept all non-special bytes; over- -// permissive). Empty tag is valid: $$ is a marker, $$body$$ is the body. +// Tag rule: starts after '$', runs to the next '$'. Per the Postgres docs a +// dollar-quote tag must not start with a digit, so we reject `$1`, `$2`, +// etc. as tags and let the scanner treat them as ordinary bytes (this is +// what `$1`-style parameter placeholders look like, even though `QueryExecModeExec` +// can't bind them in this command). Empty tag is valid: $$ is a marker, +// $$body$$ is the body. func readDollarTag(s string, start, end int) (string, int) { i := start + 1 for i < end { @@ -137,6 +140,10 @@ func readDollarTag(s string, start, end int) (string, int) { tag := s[start+1 : i] return tag, i + 1 } + // Reject `$...` early: it can't be a valid tag. + if i == start+1 && s[i] >= '0' && s[i] <= '9' { + return "", start + } // Stop at characters that can't be in a tag. if s[i] == ' ' || s[i] == '\t' || s[i] == '\n' || s[i] == ';' { return "", start diff --git a/experimental/postgres/cmd/multistatement_test.go b/experimental/postgres/cmd/multistatement_test.go index bb50bf5e8ef..ae60ee7e15f 100644 --- a/experimental/postgres/cmd/multistatement_test.go +++ b/experimental/postgres/cmd/multistatement_test.go @@ -39,6 +39,18 @@ func TestCheckSingleStatement(t *testing.T) { {name: "empty input", input: "", wantErr: false}, {name: "only whitespace", input: " \n\t ", wantErr: false}, {name: "only semicolon", input: ";", wantErr: false}, + + // $1 / $2 placeholder syntax must not be confused with a dollar-quote + // tag (tags can't start with a digit per PG docs). 
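+		// Dollar-quoted bodies, by contrast, still hide semicolons:
+		// $$a;b$$ scans as a single statement.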
+ {name: "dollar-digit placeholders", input: "SELECT $1, $2 FROM t", wantErr: false}, + {name: "dollar-digit then real semi", input: "SELECT $1 FROM t; SELECT 2", wantErr: true}, + + // E-string escape syntax: scanner doesn't honour \' escape, so a + // backslash-escaped apostrophe terminates the literal early. We + // document the over-rejection rather than fix it (acceptable v1 + // stance per the plan); pin the behaviour here so the next person + // touching the scanner has to update the test. + {name: "E-string with backslash-escape over-rejects", input: `SELECT E'foo\';bar'`, wantErr: true}, } for _, tc := range tests { diff --git a/experimental/postgres/cmd/query.go b/experimental/postgres/cmd/query.go index bc339c29d43..4bd75c8a71f 100644 --- a/experimental/postgres/cmd/query.go +++ b/experimental/postgres/cmd/query.go @@ -2,6 +2,7 @@ package postgrescmd import ( "context" + "errors" "fmt" "io" "time" @@ -128,7 +129,7 @@ func runQuery(ctx context.Context, cmd *cobra.Command, args []string, f queryFla } for _, u := range units { if err := checkSingleStatement(u.SQL); err != nil { - return fmt.Errorf("%s: %w%s", u.Source, err, multiStatementHint()) + return fmt.Errorf("%s: %w%s", u.Source, err, multiStatementHint) } } @@ -183,14 +184,18 @@ func runQuery(ctx context.Context, cmd *cobra.Command, args []string, f queryFla if len(units) == 1 { // Single-input path: stream directly through the per-format sink. // Avoids buffering rows for large exports and matches the v1 single- - // input behaviour PR 2 shipped. + // input behaviour PR 2 shipped. Wrap the error so DETAIL / HINT + // from a *pgconn.PgError surface even on the single-input path. sink := newSink(format, out, stderr) - return executeOne(ctx, conn, units[0].SQL, sink) + if err := executeOne(ctx, conn, units[0].SQL, sink); err != nil { + return errors.New(formatPgError(err)) + } + return nil } // Multi-input path: per-unit buffering. The plan accepts this trade-off // (multi-input invocations with huge SELECTs should use single-input - // invocations with --output csv for streaming). Sessions state (SET, + // invocations with --output csv for streaming). Session state (SET, // temp tables) carries across units because we hold the same connection. results := make([]*unitResult, 0, len(units)) for _, u := range units { @@ -198,7 +203,7 @@ func runQuery(ctx context.Context, cmd *cobra.Command, args []string, f queryFla if err != nil { // Render the successful prefix, then surface the error with // rich pgError formatting if applicable. - if rerr := renderPartial(out, stderr, format, results, u, err); rerr != nil { + if rerr := renderPartial(out, stderr, format, results, r, err); rerr != nil { // Best-effort partial render failed; surface the original // error to the user, the renderer error to debug logs. fmt.Fprintln(stderr, "warning: failed to render partial result:", rerr) @@ -210,7 +215,7 @@ func runQuery(ctx context.Context, cmd *cobra.Command, args []string, f queryFla switch format { case outputJSON: - return renderJSONMulti(out, stderr, results, -1, "") + return renderJSONMulti(out, stderr, results, nil, "") default: return renderTextMulti(out, results) } @@ -232,10 +237,10 @@ func newSink(format outputFormat, out, stderr io.Writer) rowSink { // renderPartial emits the rendered output for the prefix of units that ran // successfully before a unit errored. For multi-input json this also writes // the error envelope as the last array element. 
-func renderPartial(out, stderr io.Writer, format outputFormat, results []*unitResult, errored inputUnit, err error) error { +func renderPartial(out, stderr io.Writer, format outputFormat, results []*unitResult, errored *unitResult, err error) error { switch format { case outputJSON: - return renderJSONMulti(out, stderr, results, len(results), formatExecutionErrorMessage(errored.Source, err)) + return renderJSONMulti(out, stderr, results, errored, formatPgError(err)) default: // Text: render whatever ran cleanly. The error message goes through // cobra's default error path on stderr after we return. @@ -250,16 +255,8 @@ func formatExecutionError(source string, err error) error { return fmt.Errorf("%s: %s", source, formatPgError(err)) } -// formatExecutionErrorMessage is the string form of formatExecutionError, -// suitable for embedding in JSON envelopes. -func formatExecutionErrorMessage(source string, err error) string { - return fmt.Sprintf("%s: %s", source, formatPgError(err)) -} - -// multiStatementHint is the workaround pointer appended to the -// errMultipleStatements error so users see the recovery path inline. -func multiStatementHint() string { - return "\nThis command runs one statement per input. To run multiple statements:\n" + - ` - Pass each as a separate positional: query "SELECT 1" "SELECT 2"` + "\n" + - ` - Pass each in its own --file: query --file q1.sql --file q2.sql` -} +// multiStatementHint is appended to errMultipleStatements so users see the +// recovery path inline. +const multiStatementHint = "\nThis command runs one statement per input. To run multiple statements:\n" + + ` - Pass each as a separate positional: query "SELECT 1" "SELECT 2"` + "\n" + + ` - Pass each in its own --file: query --file q1.sql --file q2.sql` diff --git a/experimental/postgres/cmd/render_multi.go b/experimental/postgres/cmd/render_multi.go index 6ffc1e8d580..4cfa2063f72 100644 --- a/experimental/postgres/cmd/render_multi.go +++ b/experimental/postgres/cmd/render_multi.go @@ -4,7 +4,6 @@ import ( "bytes" "fmt" "io" - "strings" ) // renderTextMulti renders a sequence of unit results as plain text. Each @@ -57,15 +56,19 @@ func renderTextResult(out io.Writer, r *unitResult) error { // streaming. CSV multi-input is rejected pre-flight, so this function is // only used for json. // -// Per-unit shape: +// Every per-unit object shares the same canonical key order: // -// {"sql": "...", "kind": "rows", "elapsed_ms": N, "rows": [...]} -// {"sql": "...", "kind": "command", "elapsed_ms": N, "command": "...", "rows_affected": N} -// {"sql": "...", "kind": "error", "elapsed_ms": N, "error": {...}} +// {"source", "sql", "kind", "elapsed_ms", payload...} // -// kind discriminates which fields are present so consumers don't have to -// branch on key presence. -func renderJSONMulti(out, stderr io.Writer, results []*unitResult, errIndex int, errMessage string) error { +// where payload depends on kind: +// +// "rows": {..., "rows": [...]} +// "command": {..., "command": "...", "rows_affected": N} +// "error": {..., "error": {"message": "..."}} +// +// elapsed_ms is present on errors too: it captures how long the failing +// statement ran before the error fired. 
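+//
+// Illustrative two-unit output (values assumed):
+//
+//	[
+//	{"source":"argv[1]","sql":"SET search_path TO public","kind":"command","elapsed_ms":2,"command":"SET"},
+//	{"source":"argv[2]","sql":"SELECT 1","kind":"rows","elapsed_ms":3,"rows":[{"?column?":1}]}
+//	]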
+func renderJSONMulti(out, stderr io.Writer, results []*unitResult, errored *unitResult, errMessage string) error { if _, err := io.WriteString(out, "[\n"); err != nil { return err } @@ -85,21 +88,13 @@ func renderJSONMulti(out, stderr io.Writer, results []*unitResult, errIndex int, } } - if errIndex >= 0 { - // The errored unit follows the last successful unit; write a comma - // separator and the error envelope for it. + if errored != nil { if len(results) > 0 { if _, err := io.WriteString(out, ",\n"); err != nil { return err } } - errSQL := "" - errSource := "" - // errIndex points to the input *unit* index; since we render - // successful units in order, the errored unit's SQL came from the - // caller's units slice. The caller embeds it in errMessage so we - // don't need separate plumbing here. - obj := jsonErrorObject(errSource, errSQL, errMessage) + obj := jsonErrorObject(errored, errMessage) if _, err := out.Write(obj); err != nil { return err } @@ -113,18 +108,19 @@ func renderJSONMulti(out, stderr io.Writer, results []*unitResult, errIndex int, // existing single-input json rendering for the rows array so the value // mapping stays consistent across single- and multi-input shapes. func renderJSONUnit(buf *bytes.Buffer, stderr io.Writer, r *unitResult) error { + if err := writeJSONUnitHeader(buf, r); err != nil { + return err + } + if !r.IsRowsProducing() { - // Command-only unit. - if _, err := fmt.Fprintf(buf, `{"sql":`); err != nil { - return err - } - sqlJSON, err := marshalJSON(r.SQL) + buf.WriteString(`,"kind":"command"`) + fmt.Fprintf(buf, `,"elapsed_ms":%d`, r.Elapsed.Milliseconds()) + verbBytes, err := marshalJSON(commandTagVerb(r.CommandTag)) if err != nil { return err } - buf.Write(sqlJSON) - fmt.Fprintf(buf, `,"kind":"command","elapsed_ms":%d`, r.Elapsed.Milliseconds()) - fmt.Fprintf(buf, `,"command":"%s"`, jsonEscapeShort(commandTagVerb(r.CommandTag))) + buf.WriteString(`,"command":`) + buf.Write(verbBytes) if rows, ok := commandTagRowCount(r.CommandTag); ok { fmt.Fprintf(buf, `,"rows_affected":%d`, rows) } @@ -132,17 +128,10 @@ func renderJSONUnit(buf *bytes.Buffer, stderr io.Writer, r *unitResult) error { return nil } - // Rows-producing unit. We reuse jsonSink for the rows array body so - // the per-row encoding (column order, type mapping) stays in one place. - if _, err := fmt.Fprintf(buf, `{"sql":`); err != nil { - return err - } - sqlJSON, err := marshalJSON(r.SQL) - if err != nil { - return err - } - buf.Write(sqlJSON) - fmt.Fprintf(buf, `,"kind":"rows","elapsed_ms":%d,"rows":`, r.Elapsed.Milliseconds()) + // Rows-producing unit. Reuse jsonSink for the rows array body so the + // per-row encoding (column order, type mapping) stays in one place. + buf.WriteString(`,"kind":"rows"`) + fmt.Fprintf(buf, `,"elapsed_ms":%d,"rows":`, r.Elapsed.Milliseconds()) rowsBuf := &bytes.Buffer{} sink := newJSONSink(rowsBuf, stderr) @@ -154,8 +143,6 @@ func renderJSONUnit(buf *bytes.Buffer, stderr io.Writer, r *unitResult) error { return err } } - // Use a no-op tag for End so jsonSink's success path emits the closing - // bracket. The trailing newline gets trimmed below. if err := sink.End(""); err != nil { return err } @@ -165,24 +152,38 @@ func renderJSONUnit(buf *bytes.Buffer, stderr io.Writer, r *unitResult) error { return nil } -// jsonErrorObject builds the per-unit error envelope used in the multi-input -// JSON shape. message is the formatted error message (already includes -// SQLSTATE / hint / detail when applicable). 
-func jsonErrorObject(source, sql, message string) []byte { - var buf bytes.Buffer - buf.WriteString(`{"source":`) - if b, err := marshalJSON(source); err == nil { - buf.Write(b) - } else { - buf.WriteString(`""`) +// writeJSONUnitHeader writes the canonical {source, sql, ...} prefix used +// by every per-unit object. The closing brace and the kind-specific payload +// are appended by the caller. +func writeJSONUnitHeader(buf *bytes.Buffer, r *unitResult) error { + sourceBytes, err := marshalJSON(r.Source) + if err != nil { + return err + } + sqlBytes, err := marshalJSON(r.SQL) + if err != nil { + return err } + buf.WriteString(`{"source":`) + buf.Write(sourceBytes) buf.WriteString(`,"sql":`) - if b, err := marshalJSON(sql); err == nil { - buf.Write(b) - } else { - buf.WriteString(`""`) + buf.Write(sqlBytes) + return nil +} + +// jsonErrorObject builds the per-unit error envelope used in the multi-input +// JSON shape. The buffered unitResult provides source, SQL, and the elapsed +// time captured by runUnitBuffered's error path. message is the +// already-formatted error wording (includes SQLSTATE / hint / detail for +// PgErrors). +func jsonErrorObject(r *unitResult, message string) []byte { + var buf bytes.Buffer + if err := writeJSONUnitHeader(&buf, r); err != nil { + return []byte(`{"source":"","sql":"","kind":"error","elapsed_ms":0,"error":{"message":""}}`) } - buf.WriteString(`,"kind":"error","error":{"message":`) + buf.WriteString(`,"kind":"error"`) + fmt.Fprintf(&buf, `,"elapsed_ms":%d`, r.Elapsed.Milliseconds()) + buf.WriteString(`,"error":{"message":`) if b, err := marshalJSON(message); err == nil { buf.Write(b) } else { @@ -191,19 +192,3 @@ func jsonErrorObject(source, sql, message string) []byte { buf.WriteString(`}}`) return buf.Bytes() } - -// jsonEscapeShort is a fast path for short ASCII strings (command tag verbs) -// that need backslash escapes for ", \, and control bytes only. Falls back -// to a string-escaped value if the input contains anything unusual. -func jsonEscapeShort(s string) string { - if !strings.ContainsAny(s, "\"\\\n\r\t") { - return s - } - out, err := marshalJSON(s) - if err != nil { - return s - } - // marshalJSON returns the value with surrounding quotes; strip them so - // the caller can wrap with its own quoting. - return string(bytes.Trim(out, `"`)) -} diff --git a/experimental/postgres/cmd/render_multi_test.go b/experimental/postgres/cmd/render_multi_test.go index dba5174a435..b4e96f73eb8 100644 --- a/experimental/postgres/cmd/render_multi_test.go +++ b/experimental/postgres/cmd/render_multi_test.go @@ -53,16 +53,12 @@ func TestRenderJSONMulti_TwoResults(t *testing.T) { } var stdout, stderr bytes.Buffer - require.NoError(t, renderJSONMulti(&stdout, &stderr, []*unitResult{r1, r2}, -1, "")) + require.NoError(t, renderJSONMulti(&stdout, &stderr, []*unitResult{r1, r2}, nil, "")) out := stdout.String() - assert.Contains(t, out, `"sql":"INSERT INTO t VALUES (1)"`) - assert.Contains(t, out, `"kind":"command"`) - assert.Contains(t, out, `"command":"INSERT"`) - assert.Contains(t, out, `"rows_affected":1`) - assert.Contains(t, out, `"sql":"SELECT id FROM t"`) - assert.Contains(t, out, `"kind":"rows"`) - assert.Contains(t, out, `"rows":`) + // Canonical key order: source, sql, kind, elapsed_ms, payload. 
+ assert.Contains(t, out, `"source":"argv[1]","sql":"INSERT INTO t VALUES (1)","kind":"command","elapsed_ms":5,"command":"INSERT","rows_affected":1`) + assert.Contains(t, out, `"source":"argv[2]","sql":"SELECT id FROM t","kind":"rows","elapsed_ms":3,"rows":`) // Outer array framing. assert.Greater(t, len(out), 4) assert.Equal(t, byte('['), out[0]) @@ -78,12 +74,32 @@ func TestRenderJSONMulti_WithErrorAtEnd(t *testing.T) { CommandTag: "SELECT 1", Elapsed: 1 * time.Millisecond, } + errored := &unitResult{ + Source: "argv[2]", + SQL: "BROKEN SQL", + Elapsed: 2 * time.Millisecond, + } var stdout, stderr bytes.Buffer - require.NoError(t, renderJSONMulti(&stdout, &stderr, []*unitResult{r1}, 1, "argv[2]: ERROR: syntax error (SQLSTATE 42601)")) + require.NoError(t, renderJSONMulti(&stdout, &stderr, []*unitResult{r1}, errored, "ERROR: syntax error (SQLSTATE 42601)")) out := stdout.String() assert.Contains(t, out, `"kind":"rows"`) - assert.Contains(t, out, `"kind":"error"`) - assert.Contains(t, out, `"message":"argv[2]: ERROR: syntax error (SQLSTATE 42601)"`) + // Error envelope: same key order, includes elapsed_ms + source + sql. + assert.Contains(t, out, `"source":"argv[2]","sql":"BROKEN SQL","kind":"error","elapsed_ms":2,"error":{"message":"ERROR: syntax error (SQLSTATE 42601)"}`) +} + +func TestRenderJSONMulti_FirstUnitFails(t *testing.T) { + errored := &unitResult{ + Source: "argv[1]", + SQL: "BROKEN", + Elapsed: 7 * time.Millisecond, + } + var stdout, stderr bytes.Buffer + require.NoError(t, renderJSONMulti(&stdout, &stderr, nil, errored, "ERROR: bad")) + + out := stdout.String() + // No leading separator before the single error envelope. + assert.Contains(t, out, "[\n"+`{"source":"argv[1]","sql":"BROKEN","kind":"error","elapsed_ms":7,"error":{"message":"ERROR: bad"}}`) + assert.Contains(t, out, "\n]\n") } diff --git a/experimental/postgres/cmd/result.go b/experimental/postgres/cmd/result.go index d9b449a4847..ec03534bfb8 100644 --- a/experimental/postgres/cmd/result.go +++ b/experimental/postgres/cmd/result.go @@ -2,7 +2,6 @@ package postgrescmd import ( "context" - "fmt" "time" "github.com/jackc/pgx/v5" @@ -27,36 +26,44 @@ func (r *unitResult) IsRowsProducing() bool { return len(r.Fields) > 0 } -// runUnitBuffered runs sql and collects every row into memory. Used by the -// multi-input output paths (text and json), where per-unit buffering is -// acceptable per the plan: a multi-input invocation that emits a huge -// SELECT will buffer that result before printing. Users with huge result -// sets per statement should use single-input invocations (which fully -// stream) or --output csv on a single input. +// runUnitBuffered runs sql and collects every row into memory. Implemented +// as a thin wrapper that hands a bufferSink to executeOne, so error wrapping +// and the rowSink contract stay in one place rather than parallel-evolving +// across two query loops. func runUnitBuffered(ctx context.Context, conn *pgx.Conn, unit inputUnit) (*unitResult, error) { start := time.Now() - rows, err := conn.Query(ctx, unit.SQL, pgx.QueryExecModeExec) - if err != nil { - return nil, fmt.Errorf("query failed: %w", err) + r := &unitResult{Source: unit.Source, SQL: unit.SQL} + sink := &bufferSink{result: r} + if err := executeOne(ctx, conn, unit.SQL, sink); err != nil { + // Capture timing on the error path too so the JSON error envelope + // can surface "this query ran for X seconds before failing". 
+ r.Elapsed = time.Since(start) + return r, err } - defer rows.Close() - - r := &unitResult{ - Source: unit.Source, - SQL: unit.SQL, - Fields: rows.FieldDescriptions(), - } - for rows.Next() { - values, err := rows.Values() - if err != nil { - return nil, fmt.Errorf("decode row: %w", err) - } - r.Rows = append(r.Rows, values) - } - if err := rows.Err(); err != nil { - return nil, fmt.Errorf("query failed: %w", err) - } - r.CommandTag = rows.CommandTag().String() r.Elapsed = time.Since(start) return r, nil } + +// bufferSink is a rowSink that copies fields, rows, and the command tag into +// a unitResult instead of writing anywhere. Used by the multi-input path so +// successive units can be rendered together once they're all available. +type bufferSink struct { + result *unitResult +} + +func (s *bufferSink) Begin(fields []pgconn.FieldDescription) error { + s.result.Fields = fields + return nil +} + +func (s *bufferSink) Row(values []any) error { + s.result.Rows = append(s.result.Rows, values) + return nil +} + +func (s *bufferSink) End(commandTag string) error { + s.result.CommandTag = commandTag + return nil +} + +func (s *bufferSink) OnError(err error) {} From 5f193b40dff126619cfa5bc14467f8fc4e7fcbc9 Mon Sep 17 00:00:00 2001 From: simon Date: Thu, 30 Apr 2026 11:47:23 +0200 Subject: [PATCH 10/25] PR 3 r2: drop unreachable json-encoding fallback branches Round-2 reviewer noted jsonErrorObject's defensive branches around writeJSONUnitHeader/marshalJSON are unreachable (encoding/json doesn't error on string inputs), and the repo rule says drop "just in case" fallbacks. Replace with panic-on-impossible helpers. Co-authored-by: Isaac --- experimental/postgres/cmd/render_multi.go | 32 +++++++++++++++++------ 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/experimental/postgres/cmd/render_multi.go b/experimental/postgres/cmd/render_multi.go index 4cfa2063f72..2a2d7938163 100644 --- a/experimental/postgres/cmd/render_multi.go +++ b/experimental/postgres/cmd/render_multi.go @@ -176,19 +176,35 @@ func writeJSONUnitHeader(buf *bytes.Buffer, r *unitResult) error { // time captured by runUnitBuffered's error path. message is the // already-formatted error wording (includes SQLSTATE / hint / detail for // PgErrors). +// +// marshalJSON of a string never errors (encoding/json replaces invalid UTF-8 +// with U+FFFD), so the inner errors are unreachable and we treat them as +// programming errors via panic. func jsonErrorObject(r *unitResult, message string) []byte { var buf bytes.Buffer - if err := writeJSONUnitHeader(&buf, r); err != nil { - return []byte(`{"source":"","sql":"","kind":"error","elapsed_ms":0,"error":{"message":""}}`) - } + mustWriteJSONHeader(&buf, r) buf.WriteString(`,"kind":"error"`) fmt.Fprintf(&buf, `,"elapsed_ms":%d`, r.Elapsed.Milliseconds()) buf.WriteString(`,"error":{"message":`) - if b, err := marshalJSON(message); err == nil { - buf.Write(b) - } else { - buf.WriteString(`""`) - } + buf.Write(mustMarshalJSON(message)) buf.WriteString(`}}`) return buf.Bytes() } + +// mustWriteJSONHeader is writeJSONUnitHeader with a panic instead of an +// error return. The only failure mode is an unreachable encoding/json error. +func mustWriteJSONHeader(buf *bytes.Buffer, r *unitResult) { + if err := writeJSONUnitHeader(buf, r); err != nil { + panic(fmt.Errorf("encoding json header: %w", err)) + } +} + +// mustMarshalJSON is marshalJSON with a panic instead of an error return, +// for the same reason. 
+func mustMarshalJSON(v any) []byte { + b, err := marshalJSON(v) + if err != nil { + panic(fmt.Errorf("encoding json value: %w", err)) + } + return b +} From d0fe9b8d7344cc6ace3c5dab60085b30c76f2499 Mon Sep 17 00:00:00 2001 From: simon Date: Thu, 30 Apr 2026 11:51:53 +0200 Subject: [PATCH 11/25] Cancellation watcher + --timeout + TUI for >30 rows + integration tests This is PR 4 of the experimental postgres query stack and finishes the plan's v1 scope. Cancellation watcher: pgx.ConnConfig now installs a CancelRequestContextWatcherHandler with CancelRequestDelay=0 (send the cancel-request immediately on ctx cancel) and DeadlineDelay=5s (fall back to deadlining the connection if the cancel-request hasn't terminated the query within 5s). Without this, Ctrl+C tears down the TCP connection but leaves the server-side query running until it next writes. Signal handling: a per-invocation signal goroutine watches SIGINT and SIGTERM, cancelling the connection-scoped ctx. The defer'd stop drains the signal channel so a queued signal during shutdown does not leak. On Windows, Go's console-control-handler routes Ctrl+C to os.Interrupt, so the same code path covers the Windows runner. --timeout: per-statement deadline applied via context.WithTimeout. A fresh deadline starts for each input unit; the connection-scoped ctx remains the parent so a SIGINT during unit N immediately cancels both. reportCancellation distinguishes the three error sources (ctx.Canceled, ctx.DeadlineExceeded, plain pg error) so the user-visible message is "Query cancelled.", "Query timed out after Xs.", or the formatted pg error respectively. TUI for >30 rows: when --output text and stdout is a prompt-capable TTY, results larger than staticTableThreshold (=30, matching aitools) hand off to libs/tableview's interactive viewer. Smaller results stay in the static tabwriter path so non-interactive callers see no change. Integration tests live in integration/cmd/postgres/. Skipped unless DATABRICKS_POSTGRES_INTEGRATION_TARGET is set; covers single-input JSON, command-only, --timeout firing, multi-input JSON, and a CSV streaming smoke test (generate_series(1, 100)). Ctrl+C is documented as needing a separate harness because it requires a child process. Co-authored-by: Isaac --- experimental/postgres/cmd/cancel_test.go | 73 +++++++++++++++ experimental/postgres/cmd/connect.go | 22 +++++ experimental/postgres/cmd/query.go | 78 ++++++++++++---- experimental/postgres/cmd/render.go | 29 +++++- experimental/postgres/cmd/signals.go | 40 ++++++++ integration/cmd/postgres/postgres_test.go | 109 ++++++++++++++++++++++ 6 files changed, 332 insertions(+), 19 deletions(-) create mode 100644 experimental/postgres/cmd/cancel_test.go create mode 100644 experimental/postgres/cmd/signals.go create mode 100644 integration/cmd/postgres/postgres_test.go diff --git a/experimental/postgres/cmd/cancel_test.go b/experimental/postgres/cmd/cancel_test.go new file mode 100644 index 00000000000..73de49ef6bb --- /dev/null +++ b/experimental/postgres/cmd/cancel_test.go @@ -0,0 +1,73 @@ +package postgrescmd + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestWithStatementTimeout_ZeroIsPassthrough(t *testing.T) { + parent := t.Context() + got, cancel := withStatementTimeout(parent, 0) + defer cancel() + // Parent and got should compare equal: zero timeout returns the parent + // unchanged (and a no-op cancel). 
+ deadline, ok := got.Deadline() + assert.False(t, ok, "deadline should not be set when timeout is 0") + assert.True(t, deadline.IsZero()) +} + +func TestWithStatementTimeout_AppliesDeadline(t *testing.T) { + parent := t.Context() + got, cancel := withStatementTimeout(parent, time.Second) + defer cancel() + deadline, ok := got.Deadline() + assert.True(t, ok) + assert.False(t, deadline.IsZero()) +} + +func TestReportCancellation_SignalCanceled(t *testing.T) { + signalCtx, signalCancel := context.WithCancel(t.Context()) + signalCancel() + stmtCtx := signalCtx + got := reportCancellation(signalCtx, stmtCtx, errors.New("anything"), 0) + assert.Equal(t, "Query cancelled.", got) +} + +func TestReportCancellation_TimeoutFired(t *testing.T) { + signalCtx := t.Context() + stmtCtx, stmtCancel := context.WithDeadline(signalCtx, time.Now().Add(-time.Second)) + defer stmtCancel() + // Wait for the deadline to be surfaced. + <-stmtCtx.Done() + got := reportCancellation(signalCtx, stmtCtx, errors.New("query failed"), 5*time.Second) + assert.Equal(t, "Query timed out after 5s.", got) +} + +func TestReportCancellation_GenericError(t *testing.T) { + signalCtx := t.Context() + stmtCtx := signalCtx + got := reportCancellation(signalCtx, stmtCtx, errors.New("syntax error"), 0) + assert.Equal(t, "syntax error", got) +} + +func TestWatchInterruptSignals_CancelOnStop(t *testing.T) { + // stop should cancel the parent context as a side-effect so the goroutine + // terminates promptly. We don't actually send a SIGINT here (it would + // also kill the test runner); we just verify stop cleans up. + parent, parentCancel := context.WithCancel(t.Context()) + defer parentCancel() + + cancelled := false + cancel := func() { + cancelled = true + parentCancel() + } + + stop := watchInterruptSignals(parent, cancel) + stop() + assert.True(t, cancelled, "stop should call cancel to wake the goroutine") +} diff --git a/experimental/postgres/cmd/connect.go b/experimental/postgres/cmd/connect.go index 2eefc681868..cf7457aa6d1 100644 --- a/experimental/postgres/cmd/connect.go +++ b/experimental/postgres/cmd/connect.go @@ -11,6 +11,7 @@ import ( "github.com/databricks/cli/libs/log" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgconn/ctxwatch" ) // defaultConnectTimeout is the dial timeout for a single connect attempt. @@ -52,6 +53,19 @@ type connectFunc func(ctx context.Context, cfg *pgx.ConnConfig) (*pgx.Conn, erro // in the resolved values. The DSN-then-patch pattern is the recommended way // to configure pgx for `sslmode=require` because building a pgx.ConnConfig // by hand omits internal fields that the parser sets. +// +// The context-watcher handler is overridden so context cancellation issues +// a Postgres CancelRequest on the side-channel rather than only closing the +// underlying TCP connection. Without this override, a Ctrl+C during a long +// SELECT would tear down the TCP socket but leave the server-side query +// running until it noticed the broken connection on its next write. +// +// CancelRequestDelay = 0: send the cancel-request immediately on ctx cancel. +// The user just hit Ctrl+C; we want the server to learn now. +// DeadlineDelay = 5s: if the cancel-request has not gotten the server to +// terminate the query within 5s, fall back to deadlining the connection. +// Zero DeadlineDelay would race the cancel-request and could leave the +// connection unusable. 
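+//
+// Net effect of these settings on Ctrl+C during a long statement
+// (illustrative timeline):
+//
+//	t=0    ctx cancelled; the cancel-request goes out immediately
+//	t<5s   the server terminates the query; the error surfaces normally
+//	t=5s   fallback: the connection itself is deadlined and torn down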
func buildPgxConfig(c connectConfig) (*pgx.ConnConfig, error) { cfg, err := pgx.ParseConfig("postgresql:///?sslmode=require") if err != nil { @@ -63,6 +77,14 @@ func buildPgxConfig(c connectConfig) (*pgx.ConnConfig, error) { cfg.Password = c.Password cfg.Database = c.Database cfg.ConnectTimeout = c.ConnectTimeout + + cfg.BuildContextWatcherHandler = func(pgc *pgconn.PgConn) ctxwatch.Handler { + return &pgconn.CancelRequestContextWatcherHandler{ + Conn: pgc, + CancelRequestDelay: 0, + DeadlineDelay: 5 * time.Second, + } + } return cfg, nil } diff --git a/experimental/postgres/cmd/query.go b/experimental/postgres/cmd/query.go index 4bd75c8a71f..8f5e7c5ab66 100644 --- a/experimental/postgres/cmd/query.go +++ b/experimental/postgres/cmd/query.go @@ -26,6 +26,7 @@ type queryFlags struct { connectTimeout time.Duration maxRetries int files []string + timeout time.Duration // outputFormat is the raw flag value. resolveOutputFormat turns it into // the effective format (which may differ when stdout is piped). @@ -95,6 +96,7 @@ Limitations (this release): cmd.Flags().StringVarP(&f.database, "database", "d", defaultDatabase, "Database name") cmd.Flags().DurationVar(&f.connectTimeout, "connect-timeout", defaultConnectTimeout, "Connect timeout") cmd.Flags().IntVar(&f.maxRetries, "max-retries", 3, "Total connect attempts on idle/waking endpoint (must be >= 1; 1 disables retry)") + cmd.Flags().DurationVar(&f.timeout, "timeout", 0, "Per-statement timeout (0 disables)") cmd.Flags().StringArrayVarP(&f.files, "file", "f", nil, "SQL file path (repeatable)") cmd.Flags().StringVarP(&f.outputFormat, "output", "o", string(outputText), "Output format: text, json, or csv") cmd.RegisterFlagCompletionFunc("output", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) { @@ -172,10 +174,21 @@ func runQuery(ctx context.Context, cmd *cobra.Command, args []string, f queryFla MaxDelay: 10 * time.Second, } - conn, err := connectWithRetry(ctx, pgxCfg, rc, pgx.ConnectConfig) + // Invocation-scoped context: cancelled by Ctrl+C/SIGTERM. Owns the + // connection lifecycle. Per-statement timeouts are children of this so + // a cancelled invocation also cancels the in-flight statement. + signalCtx, signalCancel := context.WithCancel(ctx) + defer signalCancel() + + stopSignals := watchInterruptSignals(signalCtx, signalCancel) + defer stopSignals() + + conn, err := connectWithRetry(signalCtx, pgxCfg, rc, pgx.ConnectConfig) if err != nil { return err } + // Close on a background ctx so a cancelled signalCtx does not abort a + // clean teardown handshake. defer conn.Close(context.WithoutCancel(ctx)) out := cmd.OutOrStdout() @@ -186,9 +199,15 @@ func runQuery(ctx context.Context, cmd *cobra.Command, args []string, f queryFla // Avoids buffering rows for large exports and matches the v1 single- // input behaviour PR 2 shipped. Wrap the error so DETAIL / HINT // from a *pgconn.PgError surface even on the single-input path. - sink := newSink(format, out, stderr) - if err := executeOne(ctx, conn, units[0].SQL, sink); err != nil { - return errors.New(formatPgError(err)) + // Promote-to-interactive only when stdout is a prompt-capable TTY so + // a pipe falls back to the static table rather than launching a TUI + // into a dead writer. 
+ sink := newSinkInteractive(format, out, stderr, stdoutTTY && cmdio.IsPromptSupported(ctx)) + stmtCtx, stmtCancel := withStatementTimeout(signalCtx, f.timeout) + err := executeOne(stmtCtx, conn, units[0].SQL, sink) + stmtCancel() + if err != nil { + return errors.New(reportCancellation(signalCtx, stmtCtx, err, f.timeout)) } return nil } @@ -199,7 +218,9 @@ func runQuery(ctx context.Context, cmd *cobra.Command, args []string, f queryFla // temp tables) carries across units because we hold the same connection. results := make([]*unitResult, 0, len(units)) for _, u := range units { - r, err := runUnitBuffered(ctx, conn, u) + stmtCtx, stmtCancel := withStatementTimeout(signalCtx, f.timeout) + r, err := runUnitBuffered(stmtCtx, conn, u) + stmtCancel() if err != nil { // Render the successful prefix, then surface the error with // rich pgError formatting if applicable. @@ -208,7 +229,7 @@ func runQuery(ctx context.Context, cmd *cobra.Command, args []string, f queryFla // error to the user, the renderer error to debug logs. fmt.Fprintln(stderr, "warning: failed to render partial result:", rerr) } - return formatExecutionError(u.Source, err) + return errors.New(u.Source + ": " + reportCancellation(signalCtx, stmtCtx, err, f.timeout)) } results = append(results, r) } @@ -221,15 +242,47 @@ func runQuery(ctx context.Context, cmd *cobra.Command, args []string, f queryFla } } -// newSink returns the rowSink for the chosen output format. Kept separate -// from runQuery so tests can build sinks without going through pgx. -func newSink(format outputFormat, out, stderr io.Writer) rowSink { +// withStatementTimeout returns ctx unchanged (and a no-op cancel) when timeout +// is zero, otherwise a child context with the timeout applied. Per-statement +// scoping means a long-running unit can be cancelled without aborting the +// next unit's chance to run with a fresh deadline; today execution stops on +// the first failure either way, but the contract is what matters for v2. +func withStatementTimeout(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { + if timeout <= 0 { + return parent, func() {} + } + return context.WithTimeout(parent, timeout) +} + +// reportCancellation distinguishes the three error cases that look the +// same from `executeOne`'s POV (a wrapped pgconn / network error): user +// cancelled via Ctrl+C, --timeout fired, or the statement just plain +// errored. Returns a human-readable message; the caller wraps it. +func reportCancellation(signalCtx, stmtCtx context.Context, err error, timeout time.Duration) string { + switch { + case errors.Is(signalCtx.Err(), context.Canceled): + return "Query cancelled." + case timeout > 0 && errors.Is(stmtCtx.Err(), context.DeadlineExceeded): + return fmt.Sprintf("Query timed out after %s.", timeout) + default: + return formatPgError(err) + } +} + +// newSinkInteractive returns the rowSink for the chosen output format. When +// interactive is true the text sink may launch the libs/tableview viewer for +// results larger than staticTableThreshold; when false it uses the static +// tabwriter table. 
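+//
+// Rough decision table for --output text (sketch):
+//
+//	stdout piped                   -> static tabwriter table
+//	prompt-capable TTY, <= 30 rows -> static tabwriter table
+//	prompt-capable TTY, > 30 rows  -> libs/tableview viewer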
+func newSinkInteractive(format outputFormat, out, stderr io.Writer, interactive bool) rowSink { switch format { case outputJSON: return newJSONSink(out, stderr) case outputCSV: return newCSVSink(out, stderr) default: + if interactive { + return newInteractiveTextSink(out) + } return newTextSink(out) } } @@ -248,13 +301,6 @@ func renderPartial(out, stderr io.Writer, format outputFormat, results []*unitRe } } -// formatExecutionError produces the error returned to cobra when an input -// unit failed. The message includes the source label so the user knows -// which of N inputs blew up. -func formatExecutionError(source string, err error) error { - return fmt.Errorf("%s: %s", source, formatPgError(err)) -} - // multiStatementHint is appended to errMultipleStatements so users see the // recovery path inline. const multiStatementHint = "\nThis command runs one statement per input. To run multiple statements:\n" + diff --git a/experimental/postgres/cmd/render.go b/experimental/postgres/cmd/render.go index 2e1daf6376b..9c027529a70 100644 --- a/experimental/postgres/cmd/render.go +++ b/experimental/postgres/cmd/render.go @@ -6,25 +6,44 @@ import ( "strings" "text/tabwriter" + "github.com/databricks/cli/libs/tableview" "github.com/jackc/pgx/v5/pgconn" ) +// staticTableThreshold is the row count above which we hand off to +// libs/tableview's interactive viewer (when stdout is interactive). Smaller +// results stay in the static tabwriter path so they stream to a pipe +// unchanged. Matches the threshold aitools query uses. +const staticTableThreshold = 30 + // textSink renders results as plain text: a tabwriter-aligned table for // rows-producing statements, the command tag for command-only ones. // // Text output buffers all rows because tabwriter needs the widest cell in each // column before it can align. Streaming output is provided by the JSON and CSV // sinks; users with huge result sets should pick those. +// +// When interactive is true and the result has more than staticTableThreshold +// rows, End hands off to libs/tableview's scrollable viewer instead of +// emitting the static table. The interactive path requires a real TTY and a +// prompt-capable terminal; the caller decides. type textSink struct { - out io.Writer - columns []string - rows [][]string + out io.Writer + interactive bool + columns []string + rows [][]string } func newTextSink(out io.Writer) *textSink { return &textSink{out: out} } +// newInteractiveTextSink returns a text sink that uses the interactive table +// viewer for results larger than staticTableThreshold. 
+func newInteractiveTextSink(out io.Writer) *textSink { + return &textSink{out: out, interactive: true} +} + func (s *textSink) Begin(fields []pgconn.FieldDescription) error { s.columns = make([]string, len(fields)) for i, f := range fields { @@ -48,6 +67,10 @@ func (s *textSink) End(commandTag string) error { return err } + if s.interactive && len(s.rows) > staticTableThreshold { + return tableview.Run(s.out, s.columns, s.rows) + } + tw := tabwriter.NewWriter(s.out, 0, 0, 2, ' ', 0) fmt.Fprintln(tw, strings.Join(s.columns, "\t")) fmt.Fprintln(tw, strings.Join(headerSeparator(s.columns), "\t")) diff --git a/experimental/postgres/cmd/signals.go b/experimental/postgres/cmd/signals.go new file mode 100644 index 00000000000..b946e6b3a01 --- /dev/null +++ b/experimental/postgres/cmd/signals.go @@ -0,0 +1,40 @@ +package postgrescmd + +import ( + "context" + "os" + "os/signal" + "syscall" +) + +// watchInterruptSignals installs handlers for SIGINT and SIGTERM that call +// cancel when the user hits Ctrl+C or the process gets a SIGTERM. +// +// Returns a stop function that uninstalls the handlers; the caller must defer +// it. Calling stop drains the signal channel so a queued signal that arrived +// during shutdown does not leak. +// +// On Windows, Go maps Ctrl+C to os.Interrupt via the console-control-handler. +// The same code path therefore works for the Windows runner; the integration +// test pins this expectation. +func watchInterruptSignals(ctx context.Context, cancel context.CancelFunc) func() { + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) + + done := make(chan struct{}) + go func() { + select { + case <-sigCh: + cancel() + case <-ctx.Done(): + } + close(done) + }() + + return func() { + signal.Stop(sigCh) + // Wake the goroutine in case neither sigCh nor ctx.Done has fired. + cancel() + <-done + } +} diff --git a/integration/cmd/postgres/postgres_test.go b/integration/cmd/postgres/postgres_test.go new file mode 100644 index 00000000000..971c87ad013 --- /dev/null +++ b/integration/cmd/postgres/postgres_test.go @@ -0,0 +1,109 @@ +// Package postgres_test contains integration tests for the experimental +// `databricks experimental postgres query` command. Skipped unless an +// autoscaling resource path or provisioned instance name is provided +// via DATABRICKS_POSTGRES_INTEGRATION_TARGET. +// +// To run locally against a real Lakebase endpoint: +// +// export DATABRICKS_POSTGRES_INTEGRATION_TARGET=projects/foo/branches/main/endpoints/primary +// go test ./integration/cmd/postgres/... -v +package postgres_test + +import ( + "os" + "runtime" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + _ "github.com/databricks/cli/cmd/experimental" + "github.com/databricks/cli/internal/testcli" +) + +// targetEnv is the env var that gates these tests. Either a provisioned +// instance name or an autoscaling resource path; the command picks the +// right resolver based on the leading "projects/" segment. 
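+//
+// Accepted forms, e.g. (the provisioned name is illustrative):
+//
+//	projects/foo/branches/main/endpoints/primary   (autoscaling)
+//	my-lakebase-instance                           (provisioned)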
+const targetEnv = "DATABRICKS_POSTGRES_INTEGRATION_TARGET" + +func requireTarget(t *testing.T) string { + target := os.Getenv(targetEnv) + if target == "" { + t.Skipf("set %s to run postgres integration tests", targetEnv) + } + return target +} + +func TestPostgresQuery_SimpleSelect(t *testing.T) { + target := requireTarget(t) + ctx := t.Context() + + stdout, _ := testcli.RequireSuccessfulRun(t, ctx, "experimental", "postgres", "query", + "--target", target, "--output", "json", "SELECT 1 AS x") + + out := stdout.String() + assert.Contains(t, out, `"x":1`) +} + +func TestPostgresQuery_CommandOnly(t *testing.T) { + target := requireTarget(t) + ctx := t.Context() + + stdout, _ := testcli.RequireSuccessfulRun(t, ctx, "experimental", "postgres", "query", + "--target", target, "--output", "json", "SET search_path TO public") + + out := stdout.String() + assert.Contains(t, out, `"command":"SET"`) +} + +func TestPostgresQuery_TimeoutFires(t *testing.T) { + target := requireTarget(t) + ctx := t.Context() + + // pg_sleep(5) with --timeout 1s should fail in well under 5s. + start := time.Now() + _, stderr, err := testcli.RequireErrorRun(t, ctx, "experimental", "postgres", "query", + "--target", target, "--timeout", "1s", "SELECT pg_sleep(5)") + require.Error(t, err) + assert.Less(t, time.Since(start), 5*time.Second, "--timeout should cancel before pg_sleep finishes") + assert.Contains(t, stderr.String(), "timed out after 1s") +} + +func TestPostgresQuery_CancelOnInterrupt(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Ctrl+C signal-driven cancel test is run via a separate harness on Windows") + } + requireTarget(t) + t.Skip("manual: signal-driven cancel must be exercised with a child process; see plan section 'Cancellation and timeout'") +} + +func TestPostgresQuery_StreamingCSV(t *testing.T) { + target := requireTarget(t) + ctx := t.Context() + + // generate_series streams via pgx without buffering into memory; pick a + // small-but-non-trivial bound so the test stays fast. + stdout, _ := testcli.RequireSuccessfulRun(t, ctx, "experimental", "postgres", "query", + "--target", target, "--output", "csv", "SELECT * FROM generate_series(1, 100) AS s") + + lines := strings.Split(strings.TrimRight(stdout.String(), "\n"), "\n") + assert.GreaterOrEqual(t, len(lines), 101, "expected header + 100 rows") + assert.Equal(t, "s", lines[0]) +} + +func TestPostgresQuery_MultiInputJSON(t *testing.T) { + target := requireTarget(t) + ctx := t.Context() + + stdout, _ := testcli.RequireSuccessfulRun(t, ctx, "experimental", "postgres", "query", + "--target", target, "--output", "json", + "SELECT 1 AS a", "SELECT 2 AS b") + + out := stdout.String() + assert.Contains(t, out, `"sql":"SELECT 1 AS a"`) + assert.Contains(t, out, `"sql":"SELECT 2 AS b"`) + assert.Contains(t, out, `"a":1`) + assert.Contains(t, out, `"b":2`) +} From 43b26670a41a88c8e5769bdeafed37558e3fda6a Mon Sep 17 00:00:00 2001 From: simon Date: Thu, 30 Apr 2026 12:02:58 +0200 Subject: [PATCH 12/25] Address PR 4 review feedback round 1 SHOULDs: - signals.go: drop the false "drains the signal channel" claim from the doc comment. signal.Stop blocks future deliveries; the 1-buffered channel is GC'd on return so no explicit drain is needed. - integration test: drop the dead `_ "cmd/experimental"` blank import (testcli.NewRunner already pulls cmd/experimental in transitively). - integration test: delete the cancel-on-interrupt stub; documented as a follow-up because Ctrl+C testing requires a child-process harness that's outside the scope of this PR. 
- query.go: when an invocation-scoped error fires (Ctrl+C, --timeout) in
  multi-input mode, drop the `:` prefix. The user knows which invocation
  they cancelled; "--file foo.sql: Query cancelled." reads worse than
  "Query cancelled." reportCancellation now returns (msg, invocationScoped)
  so the caller picks the right shape.
- withStatementTimeout: trim the v2-speculation from the doc comment.

CONSIDERs:
- C2: reword watchInterruptSignals's stop-closure doc to acknowledge
  that it cancels the parent ctx as a side effect.
- C4: TestReportCancellation_BothFire_CancelWinsRace pins the precedence
  (user cancel beats coincidental deadline).
- C6: drop the redundant require.Error after RequireErrorRun (which
  already calls require.Error internally).

Plus integration test polish:
- Parse JSON outputs instead of substring-matching so encoder drift
  doesn't break tests.
- Tighten timeout assertion from <5s to <3s so a regression to
  TCP-keepalive timeout (~minutes) would show.
- Bump generate_series bound from 100 to 100k so streaming actually
  exercises memory pressure.

Co-authored-by: Isaac
---
 experimental/postgres/cmd/cancel_test.go  | 30 ++++++++---
 experimental/postgres/cmd/query.go        | 42 ++++++++++------
 experimental/postgres/cmd/signals.go      | 12 ++---
 integration/cmd/postgres/postgres_test.go | 61 ++++++++++++-----------
 4 files changed, 89 insertions(+), 56 deletions(-)

diff --git a/experimental/postgres/cmd/cancel_test.go b/experimental/postgres/cmd/cancel_test.go
index 73de49ef6bb..4245b905efc 100644
--- a/experimental/postgres/cmd/cancel_test.go
+++ b/experimental/postgres/cmd/cancel_test.go
@@ -33,25 +33,41 @@ func TestReportCancellation_SignalCanceled(t *testing.T) {
 	signalCtx, signalCancel := context.WithCancel(t.Context())
 	signalCancel()
 	stmtCtx := signalCtx
-	got := reportCancellation(signalCtx, stmtCtx, errors.New("anything"), 0)
-	assert.Equal(t, "Query cancelled.", got)
+	msg, invocationScoped := reportCancellation(signalCtx, stmtCtx, errors.New("anything"), 0)
+	assert.Equal(t, "Query cancelled.", msg)
+	assert.True(t, invocationScoped)
 }
 
 func TestReportCancellation_TimeoutFired(t *testing.T) {
 	signalCtx := t.Context()
 	stmtCtx, stmtCancel := context.WithDeadline(signalCtx, time.Now().Add(-time.Second))
 	defer stmtCancel()
-	// Wait for the deadline to be surfaced.
 	<-stmtCtx.Done()
-	got := reportCancellation(signalCtx, stmtCtx, errors.New("query failed"), 5*time.Second)
-	assert.Equal(t, "Query timed out after 5s.", got)
+	msg, invocationScoped := reportCancellation(signalCtx, stmtCtx, errors.New("query failed"), 5*time.Second)
+	assert.Equal(t, "Query timed out after 5s.", msg)
+	assert.True(t, invocationScoped)
 }
 
 func TestReportCancellation_GenericError(t *testing.T) {
 	signalCtx := t.Context()
 	stmtCtx := signalCtx
-	got := reportCancellation(signalCtx, stmtCtx, errors.New("syntax error"), 0)
-	assert.Equal(t, "syntax error", got)
+	msg, invocationScoped := reportCancellation(signalCtx, stmtCtx, errors.New("syntax error"), 0)
+	assert.Equal(t, "syntax error", msg)
+	assert.False(t, invocationScoped)
+}
+
+func TestReportCancellation_BothFire_CancelWinsRace(t *testing.T) {
+	// User cancel and deadline both already done. Precedence: cancel wins
+	// (the user's intent dominates a coincidental deadline). A future
+	// reordering of the switch would silently flip this; the test pins it.
+ signalCtx, signalCancel := context.WithCancel(t.Context()) + signalCancel() + stmtCtx, stmtCancel := context.WithDeadline(signalCtx, time.Now().Add(-time.Second)) + defer stmtCancel() + <-stmtCtx.Done() + msg, invocationScoped := reportCancellation(signalCtx, stmtCtx, errors.New("anything"), time.Second) + assert.Equal(t, "Query cancelled.", msg) + assert.True(t, invocationScoped) } func TestWatchInterruptSignals_CancelOnStop(t *testing.T) { diff --git a/experimental/postgres/cmd/query.go b/experimental/postgres/cmd/query.go index 8f5e7c5ab66..c8ed4210591 100644 --- a/experimental/postgres/cmd/query.go +++ b/experimental/postgres/cmd/query.go @@ -207,7 +207,8 @@ func runQuery(ctx context.Context, cmd *cobra.Command, args []string, f queryFla err := executeOne(stmtCtx, conn, units[0].SQL, sink) stmtCancel() if err != nil { - return errors.New(reportCancellation(signalCtx, stmtCtx, err, f.timeout)) + msg, _ := reportCancellation(signalCtx, stmtCtx, err, f.timeout) + return errors.New(msg) } return nil } @@ -229,7 +230,14 @@ func runQuery(ctx context.Context, cmd *cobra.Command, args []string, f queryFla // error to the user, the renderer error to debug logs. fmt.Fprintln(stderr, "warning: failed to render partial result:", rerr) } - return errors.New(u.Source + ": " + reportCancellation(signalCtx, stmtCtx, err, f.timeout)) + msg, invocationScoped := reportCancellation(signalCtx, stmtCtx, err, f.timeout) + if invocationScoped { + // User cancel / timeout is invocation-scoped; the source + // prefix is redundant ("--file foo.sql: Query cancelled." + // reads worse than just "Query cancelled."). + return errors.New(msg) + } + return errors.New(u.Source + ": " + msg) } results = append(results, r) } @@ -242,11 +250,10 @@ func runQuery(ctx context.Context, cmd *cobra.Command, args []string, f queryFla } } -// withStatementTimeout returns ctx unchanged (and a no-op cancel) when timeout -// is zero, otherwise a child context with the timeout applied. Per-statement -// scoping means a long-running unit can be cancelled without aborting the -// next unit's chance to run with a fresh deadline; today execution stops on -// the first failure either way, but the contract is what matters for v2. +// withStatementTimeout returns ctx unchanged (and a no-op cancel) when +// timeout is zero, otherwise a child context with the timeout applied. Each +// statement gets its own deadline so cancellation is scoped to one +// statement at a time. func withStatementTimeout(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { if timeout <= 0 { return parent, func() {} @@ -254,18 +261,23 @@ func withStatementTimeout(parent context.Context, timeout time.Duration) (contex return context.WithTimeout(parent, timeout) } -// reportCancellation distinguishes the three error cases that look the -// same from `executeOne`'s POV (a wrapped pgconn / network error): user -// cancelled via Ctrl+C, --timeout fired, or the statement just plain -// errored. Returns a human-readable message; the caller wraps it. -func reportCancellation(signalCtx, stmtCtx context.Context, err error, timeout time.Duration) string { +// reportCancellation distinguishes the three error cases that look the same +// from `executeOne`'s POV (a wrapped pgconn / network error): user cancelled +// via Ctrl+C, --timeout fired, or the statement just plain errored. Returns +// the human-readable message and whether the cause is invocation-scoped +// (cancel/timeout) rather than statement-scoped. 
+// +// Precedence: user cancel beats deadline. If both contexts fire near- +// simultaneously (race), we report "cancelled" because the user's intent +// dominates a coincidental timeout. +func reportCancellation(signalCtx, stmtCtx context.Context, err error, timeout time.Duration) (msg string, invocationScoped bool) { switch { case errors.Is(signalCtx.Err(), context.Canceled): - return "Query cancelled." + return "Query cancelled.", true case timeout > 0 && errors.Is(stmtCtx.Err(), context.DeadlineExceeded): - return fmt.Sprintf("Query timed out after %s.", timeout) + return fmt.Sprintf("Query timed out after %s.", timeout), true default: - return formatPgError(err) + return formatPgError(err), false } } diff --git a/experimental/postgres/cmd/signals.go b/experimental/postgres/cmd/signals.go index b946e6b3a01..5e4c29346f9 100644 --- a/experimental/postgres/cmd/signals.go +++ b/experimental/postgres/cmd/signals.go @@ -10,13 +10,13 @@ import ( // watchInterruptSignals installs handlers for SIGINT and SIGTERM that call // cancel when the user hits Ctrl+C or the process gets a SIGTERM. // -// Returns a stop function that uninstalls the handlers; the caller must defer -// it. Calling stop drains the signal channel so a queued signal that arrived -// during shutdown does not leak. +// Returns a stop-and-cancel function that uninstalls the handlers (signal.Stop +// prevents future OS deliveries) and cancels the parent context so the +// goroutine wakes promptly. The caller must defer it. The channel is +// 1-buffered and GC'd on return; no explicit drain is needed. // -// On Windows, Go maps Ctrl+C to os.Interrupt via the console-control-handler. -// The same code path therefore works for the Windows runner; the integration -// test pins this expectation. +// On Windows, Go maps Ctrl+C to os.Interrupt via the console-control-handler, +// so the same code path covers Windows. func watchInterruptSignals(ctx context.Context, cancel context.CancelFunc) func() { sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) diff --git a/integration/cmd/postgres/postgres_test.go b/integration/cmd/postgres/postgres_test.go index 971c87ad013..32800b3b220 100644 --- a/integration/cmd/postgres/postgres_test.go +++ b/integration/cmd/postgres/postgres_test.go @@ -3,15 +3,21 @@ // autoscaling resource path or provisioned instance name is provided // via DATABRICKS_POSTGRES_INTEGRATION_TARGET. // -// To run locally against a real Lakebase endpoint: +// To run locally against a real Lakebase endpoint, set both the standard +// auth env (DATABRICKS_HOST + DATABRICKS_TOKEN, or a configured profile) +// and the target: // // export DATABRICKS_POSTGRES_INTEGRATION_TARGET=projects/foo/branches/main/endpoints/primary // go test ./integration/cmd/postgres/... -v +// +// Ctrl+C cancellation is intentionally not in this suite: it requires a +// child-process harness (the test runner cannot share signal handlers +// with the in-process command). Tracked as a follow-up. 
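The precedence contract is small enough to demonstrate end to end. A self-contained sketch, with err.Error() standing in for formatPgError so it runs without the rest of the package (the switch mirrors reportCancellation as patched above; main recreates the race that the BothFire test pins):

    package main

    import (
        "context"
        "errors"
        "fmt"
        "time"
    )

    // Switch body mirrors reportCancellation above; err.Error() stands in
    // for formatPgError so the sketch has no other dependencies.
    func reportCancellation(signalCtx, stmtCtx context.Context, err error, timeout time.Duration) (string, bool) {
        switch {
        case errors.Is(signalCtx.Err(), context.Canceled):
            return "Query cancelled.", true
        case timeout > 0 && errors.Is(stmtCtx.Err(), context.DeadlineExceeded):
            return fmt.Sprintf("Query timed out after %s.", timeout), true
        default:
            return err.Error(), false
        }
    }

    func main() {
        // Both contexts are already done: the user cancelled and the
        // per-statement deadline has passed.
        signalCtx, cancel := context.WithCancel(context.Background())
        cancel()
        stmtCtx, stmtCancel := context.WithDeadline(signalCtx, time.Now().Add(-time.Second))
        defer stmtCancel()
        <-stmtCtx.Done()

        msg, scoped := reportCancellation(signalCtx, stmtCtx, errors.New("write: broken pipe"), time.Second)
        fmt.Println(msg, scoped) // "Query cancelled. true" (user intent wins)
    }

Because the signal case sits first in the switch, the coincidental deadline never shadows an explicit Ctrl+C.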
package postgres_test import ( + "encoding/json" "os" - "runtime" "strings" "testing" "time" @@ -19,7 +25,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - _ "github.com/databricks/cli/cmd/experimental" "github.com/databricks/cli/internal/testcli" ) @@ -43,8 +48,12 @@ func TestPostgresQuery_SimpleSelect(t *testing.T) { stdout, _ := testcli.RequireSuccessfulRun(t, ctx, "experimental", "postgres", "query", "--target", target, "--output", "json", "SELECT 1 AS x") - out := stdout.String() - assert.Contains(t, out, `"x":1`) + // Parsing the JSON instead of substring-matching makes the test robust + // to encoder formatting drift (whitespace, key order). + var rows []map[string]any + require.NoError(t, json.Unmarshal(stdout.Bytes(), &rows)) + require.Len(t, rows, 1) + assert.EqualValues(t, 1, rows[0]["x"]) } func TestPostgresQuery_CommandOnly(t *testing.T) { @@ -54,42 +63,38 @@ func TestPostgresQuery_CommandOnly(t *testing.T) { stdout, _ := testcli.RequireSuccessfulRun(t, ctx, "experimental", "postgres", "query", "--target", target, "--output", "json", "SET search_path TO public") - out := stdout.String() - assert.Contains(t, out, `"command":"SET"`) + var obj map[string]any + require.NoError(t, json.Unmarshal(stdout.Bytes(), &obj)) + assert.Equal(t, "SET", obj["command"]) } func TestPostgresQuery_TimeoutFires(t *testing.T) { target := requireTarget(t) ctx := t.Context() - // pg_sleep(5) with --timeout 1s should fail in well under 5s. + // pg_sleep(5) with --timeout 1s should fail well within the watcher's + // 5s DeadlineDelay. A loose <5s bound would still pass even if the + // watcher silently regressed to TCP-keepalive timeout (~minutes); the + // tighter <3s bound catches that. start := time.Now() - _, stderr, err := testcli.RequireErrorRun(t, ctx, "experimental", "postgres", "query", + _, stderr, _ := testcli.RequireErrorRun(t, ctx, "experimental", "postgres", "query", "--target", target, "--timeout", "1s", "SELECT pg_sleep(5)") - require.Error(t, err) - assert.Less(t, time.Since(start), 5*time.Second, "--timeout should cancel before pg_sleep finishes") + assert.Less(t, time.Since(start), 3*time.Second, "--timeout should cancel before pg_sleep finishes") assert.Contains(t, stderr.String(), "timed out after 1s") } -func TestPostgresQuery_CancelOnInterrupt(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("Ctrl+C signal-driven cancel test is run via a separate harness on Windows") - } - requireTarget(t) - t.Skip("manual: signal-driven cancel must be exercised with a child process; see plan section 'Cancellation and timeout'") -} - func TestPostgresQuery_StreamingCSV(t *testing.T) { target := requireTarget(t) ctx := t.Context() - // generate_series streams via pgx without buffering into memory; pick a - // small-but-non-trivial bound so the test stays fast. + // 100k rows is large enough to exercise streaming under realistic memory + // pressure (the buffered text path would still complete but allocate + // the whole result; the streaming CSV path keeps allocations bounded). 
stdout, _ := testcli.RequireSuccessfulRun(t, ctx, "experimental", "postgres", "query", - "--target", target, "--output", "csv", "SELECT * FROM generate_series(1, 100) AS s") + "--target", target, "--output", "csv", "SELECT * FROM generate_series(1, 100000) AS s") lines := strings.Split(strings.TrimRight(stdout.String(), "\n"), "\n") - assert.GreaterOrEqual(t, len(lines), 101, "expected header + 100 rows") + assert.GreaterOrEqual(t, len(lines), 100001, "expected header + 100000 rows") assert.Equal(t, "s", lines[0]) } @@ -101,9 +106,9 @@ func TestPostgresQuery_MultiInputJSON(t *testing.T) { "--target", target, "--output", "json", "SELECT 1 AS a", "SELECT 2 AS b") - out := stdout.String() - assert.Contains(t, out, `"sql":"SELECT 1 AS a"`) - assert.Contains(t, out, `"sql":"SELECT 2 AS b"`) - assert.Contains(t, out, `"a":1`) - assert.Contains(t, out, `"b":2`) + var results []map[string]any + require.NoError(t, json.Unmarshal(stdout.Bytes(), &results)) + require.Len(t, results, 2) + assert.Equal(t, "SELECT 1 AS a", results[0]["sql"]) + assert.Equal(t, "SELECT 2 AS b", results[1]["sql"]) } From c57805f858693dca9b7f42fc47b7e3937c73c3ba Mon Sep 17 00:00:00 2001 From: simon Date: Thu, 30 Apr 2026 12:07:33 +0200 Subject: [PATCH 13/25] PR 4 r2: warm-up TimeoutFires test before timing --timeout Round-2 reviewer flagged that the previous <3s bound was tight enough to flake on a cold Lakebase autoscaling endpoint, where auth + connect + cold-start can plausibly take >2s on its own. The regression we actually want to catch (silent fall-back to TCP keepalive) takes minutes, so <5s is enough. Add a warm-up RequireSuccessfulRun before timing so the assertion measures what it claims to measure: how long the 1-second deadline takes to actually cancel the in-flight statement. Co-authored-by: Isaac --- integration/cmd/postgres/postgres_test.go | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/integration/cmd/postgres/postgres_test.go b/integration/cmd/postgres/postgres_test.go index 32800b3b220..76d3127ccb5 100644 --- a/integration/cmd/postgres/postgres_test.go +++ b/integration/cmd/postgres/postgres_test.go @@ -72,14 +72,22 @@ func TestPostgresQuery_TimeoutFires(t *testing.T) { target := requireTarget(t) ctx := t.Context() + // Warm up first: pay the auth + connect (and potential cold-start) + // cost before timing the --timeout assertion. Without this, a cold + // Lakebase autoscaling endpoint could push the timed run past any + // reasonable deadline even though --timeout did exactly the right + // thing. Now `start` measures what we care about: how long the + // 1-second deadline takes to actually cancel the in-flight statement. + testcli.RequireSuccessfulRun(t, ctx, "experimental", "postgres", "query", + "--target", target, "--output", "json", "SELECT 1") + // pg_sleep(5) with --timeout 1s should fail well within the watcher's - // 5s DeadlineDelay. A loose <5s bound would still pass even if the - // watcher silently regressed to TCP-keepalive timeout (~minutes); the - // tighter <3s bound catches that. + // 5s DeadlineDelay. <5s rules out a silent regression to the + // TCP-keepalive timeout (~minutes). 
start := time.Now() _, stderr, _ := testcli.RequireErrorRun(t, ctx, "experimental", "postgres", "query", "--target", target, "--timeout", "1s", "SELECT pg_sleep(5)") - assert.Less(t, time.Since(start), 3*time.Second, "--timeout should cancel before pg_sleep finishes") + assert.Less(t, time.Since(start), 5*time.Second, "--timeout should cancel before pg_sleep finishes") assert.Contains(t, stderr.String(), "timed out after 1s") } From 6cd264e55f81a406dbc3d59d169c9574c315c964 Mon Sep 17 00:00:00 2001 From: simon Date: Thu, 30 Apr 2026 12:21:48 +0200 Subject: [PATCH 14/25] Drop integration tests for experimental command Other experimental commands (aitools) have no integration tests; an experimental command is by definition pre-stabilization, and gating its real-wire test on a custom env var introduces friction without a clear win. Acceptance tests + unit tests already cover argument validation, targeting resolution (SDK-mocked), and the streaming / multi-input output shapes. The cancellation watcher and --timeout are unit-tested via the seam in cancel_test.go. When this command graduates from experimental, integration tests are the right addition; for v1 they were over-engineered. Co-authored-by: Isaac --- integration/cmd/postgres/postgres_test.go | 122 ---------------------- 1 file changed, 122 deletions(-) delete mode 100644 integration/cmd/postgres/postgres_test.go diff --git a/integration/cmd/postgres/postgres_test.go b/integration/cmd/postgres/postgres_test.go deleted file mode 100644 index 76d3127ccb5..00000000000 --- a/integration/cmd/postgres/postgres_test.go +++ /dev/null @@ -1,122 +0,0 @@ -// Package postgres_test contains integration tests for the experimental -// `databricks experimental postgres query` command. Skipped unless an -// autoscaling resource path or provisioned instance name is provided -// via DATABRICKS_POSTGRES_INTEGRATION_TARGET. -// -// To run locally against a real Lakebase endpoint, set both the standard -// auth env (DATABRICKS_HOST + DATABRICKS_TOKEN, or a configured profile) -// and the target: -// -// export DATABRICKS_POSTGRES_INTEGRATION_TARGET=projects/foo/branches/main/endpoints/primary -// go test ./integration/cmd/postgres/... -v -// -// Ctrl+C cancellation is intentionally not in this suite: it requires a -// child-process harness (the test runner cannot share signal handlers -// with the in-process command). Tracked as a follow-up. -package postgres_test - -import ( - "encoding/json" - "os" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/databricks/cli/internal/testcli" -) - -// targetEnv is the env var that gates these tests. Either a provisioned -// instance name or an autoscaling resource path; the command picks the -// right resolver based on the leading "projects/" segment. -const targetEnv = "DATABRICKS_POSTGRES_INTEGRATION_TARGET" - -func requireTarget(t *testing.T) string { - target := os.Getenv(targetEnv) - if target == "" { - t.Skipf("set %s to run postgres integration tests", targetEnv) - } - return target -} - -func TestPostgresQuery_SimpleSelect(t *testing.T) { - target := requireTarget(t) - ctx := t.Context() - - stdout, _ := testcli.RequireSuccessfulRun(t, ctx, "experimental", "postgres", "query", - "--target", target, "--output", "json", "SELECT 1 AS x") - - // Parsing the JSON instead of substring-matching makes the test robust - // to encoder formatting drift (whitespace, key order). 
- var rows []map[string]any - require.NoError(t, json.Unmarshal(stdout.Bytes(), &rows)) - require.Len(t, rows, 1) - assert.EqualValues(t, 1, rows[0]["x"]) -} - -func TestPostgresQuery_CommandOnly(t *testing.T) { - target := requireTarget(t) - ctx := t.Context() - - stdout, _ := testcli.RequireSuccessfulRun(t, ctx, "experimental", "postgres", "query", - "--target", target, "--output", "json", "SET search_path TO public") - - var obj map[string]any - require.NoError(t, json.Unmarshal(stdout.Bytes(), &obj)) - assert.Equal(t, "SET", obj["command"]) -} - -func TestPostgresQuery_TimeoutFires(t *testing.T) { - target := requireTarget(t) - ctx := t.Context() - - // Warm up first: pay the auth + connect (and potential cold-start) - // cost before timing the --timeout assertion. Without this, a cold - // Lakebase autoscaling endpoint could push the timed run past any - // reasonable deadline even though --timeout did exactly the right - // thing. Now `start` measures what we care about: how long the - // 1-second deadline takes to actually cancel the in-flight statement. - testcli.RequireSuccessfulRun(t, ctx, "experimental", "postgres", "query", - "--target", target, "--output", "json", "SELECT 1") - - // pg_sleep(5) with --timeout 1s should fail well within the watcher's - // 5s DeadlineDelay. <5s rules out a silent regression to the - // TCP-keepalive timeout (~minutes). - start := time.Now() - _, stderr, _ := testcli.RequireErrorRun(t, ctx, "experimental", "postgres", "query", - "--target", target, "--timeout", "1s", "SELECT pg_sleep(5)") - assert.Less(t, time.Since(start), 5*time.Second, "--timeout should cancel before pg_sleep finishes") - assert.Contains(t, stderr.String(), "timed out after 1s") -} - -func TestPostgresQuery_StreamingCSV(t *testing.T) { - target := requireTarget(t) - ctx := t.Context() - - // 100k rows is large enough to exercise streaming under realistic memory - // pressure (the buffered text path would still complete but allocate - // the whole result; the streaming CSV path keeps allocations bounded). - stdout, _ := testcli.RequireSuccessfulRun(t, ctx, "experimental", "postgres", "query", - "--target", target, "--output", "csv", "SELECT * FROM generate_series(1, 100000) AS s") - - lines := strings.Split(strings.TrimRight(stdout.String(), "\n"), "\n") - assert.GreaterOrEqual(t, len(lines), 100001, "expected header + 100000 rows") - assert.Equal(t, "s", lines[0]) -} - -func TestPostgresQuery_MultiInputJSON(t *testing.T) { - target := requireTarget(t) - ctx := t.Context() - - stdout, _ := testcli.RequireSuccessfulRun(t, ctx, "experimental", "postgres", "query", - "--target", target, "--output", "json", - "SELECT 1 AS a", "SELECT 2 AS b") - - var results []map[string]any - require.NoError(t, json.Unmarshal(stdout.Bytes(), &results)) - require.Len(t, results, 2) - assert.Equal(t, "SELECT 1 AS a", results[0]["sql"]) - assert.Equal(t, "SELECT 2 AS b", results[1]["sql"]) -} From 5a27bf0034e479666ae6a7aa28105721a0c2fe6a Mon Sep 17 00:00:00 2001 From: simon Date: Thu, 30 Apr 2026 12:58:53 +0200 Subject: [PATCH 15/25] Cut blast radius: keep target package + acceptance tests inside experimental Two reductions for an experimental command, per a maintainer comment: - Move libs/lakebase/target into experimental/postgres/cmd/internal/target so the experiment is self-contained. cmd/psql is no longer touched (no refactor, no behavior change). When/if this command graduates from experimental, that's the right time to extract the shared package. - Drop acceptance tests for the new command. 
Aitools (the other experimental command) has none either; locking down user-visible wording for an experimental surface is overinvestment. Unit tests still cover argument validation, retry classification, and rendering. Acceptance tests can be added when the command graduates. Net diff on cmd/psql is now zero. The experiment lives entirely under experimental/postgres/cmd/. Co-authored-by: Isaac --- .../query/ambiguous-targeting/out.test.toml | 8 -- .../query/ambiguous-targeting/output.txt | 18 --- .../postgres/query/ambiguous-targeting/script | 8 -- .../query/ambiguous-targeting/test.toml | 62 -------- .../query/argument-errors/out.test.toml | 8 -- .../postgres/query/argument-errors/output.txt | 48 ------- .../postgres/query/argument-errors/script | 35 ----- .../postgres/query/argument-errors/test.toml | 3 - .../cmd/psql/argument-errors/output.txt | 4 - acceptance/cmd/psql/argument-errors/script | 3 - acceptance/cmd/psql/postgres/output.txt | 2 +- cmd/psql/psql.go | 61 ++++++-- cmd/psql/psql_autoscaling.go | 132 ++++++++++++------ cmd/psql/psql_provisioned.go | 47 +++++-- cmd/psql/psql_test.go | 83 +++++++++++ .../cmd/internal}/target/autoscaling.go | 0 .../cmd/internal}/target/provisioned.go | 0 .../postgres/cmd/internal}/target/target.go | 0 .../cmd/internal}/target/target_test.go | 0 experimental/postgres/cmd/targeting.go | 2 +- 20 files changed, 257 insertions(+), 267 deletions(-) delete mode 100644 acceptance/cmd/experimental/postgres/query/ambiguous-targeting/out.test.toml delete mode 100644 acceptance/cmd/experimental/postgres/query/ambiguous-targeting/output.txt delete mode 100644 acceptance/cmd/experimental/postgres/query/ambiguous-targeting/script delete mode 100644 acceptance/cmd/experimental/postgres/query/ambiguous-targeting/test.toml delete mode 100644 acceptance/cmd/experimental/postgres/query/argument-errors/out.test.toml delete mode 100644 acceptance/cmd/experimental/postgres/query/argument-errors/output.txt delete mode 100644 acceptance/cmd/experimental/postgres/query/argument-errors/script delete mode 100644 acceptance/cmd/experimental/postgres/query/argument-errors/test.toml create mode 100644 cmd/psql/psql_test.go rename {libs/lakebase => experimental/postgres/cmd/internal}/target/autoscaling.go (100%) rename {libs/lakebase => experimental/postgres/cmd/internal}/target/provisioned.go (100%) rename {libs/lakebase => experimental/postgres/cmd/internal}/target/target.go (100%) rename {libs/lakebase => experimental/postgres/cmd/internal}/target/target_test.go (100%) diff --git a/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/out.test.toml b/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/out.test.toml deleted file mode 100644 index 40bb0d10471..00000000000 --- a/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/out.test.toml +++ /dev/null @@ -1,8 +0,0 @@ -Local = true -Cloud = false - -[GOOS] - windows = false - -[EnvMatrix] - DATABRICKS_BUNDLE_ENGINE = ["terraform", "direct"] diff --git a/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/output.txt b/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/output.txt deleted file mode 100644 index e95a7b3613d..00000000000 --- a/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/output.txt +++ /dev/null @@ -1,18 +0,0 @@ - -=== Project with multiple branches and no --branch should error with choices: ->>> musterr [CLI] experimental postgres query --project foo SELECT 1 -Error: multiple branches found in projects/foo; specify --branch: - - main - - 
dev - -=== Project with multiple endpoints in only branch should error with choices: ->>> musterr [CLI] experimental postgres query --project bar SELECT 1 -Error: multiple endpoints found in projects/bar/branches/only; specify --endpoint: - - read-write - - read-only - -=== Partial path with multiple branches should error with choices: ->>> musterr [CLI] experimental postgres query --target projects/foo SELECT 1 -Error: multiple branches found in projects/foo; specify --branch: - - main - - dev diff --git a/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/script b/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/script deleted file mode 100644 index 6143fd96f02..00000000000 --- a/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/script +++ /dev/null @@ -1,8 +0,0 @@ -title "Project with multiple branches and no --branch should error with choices:" -trace musterr $CLI experimental postgres query --project foo "SELECT 1" - -title "Project with multiple endpoints in only branch should error with choices:" -trace musterr $CLI experimental postgres query --project bar "SELECT 1" - -title "Partial path with multiple branches should error with choices:" -trace musterr $CLI experimental postgres query --target projects/foo "SELECT 1" diff --git a/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/test.toml b/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/test.toml deleted file mode 100644 index 2a61e7e8e25..00000000000 --- a/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/test.toml +++ /dev/null @@ -1,62 +0,0 @@ -GOOS.windows = false - -[[Server]] -Pattern = "GET /api/2.0/postgres/projects" -Response.Body = ''' -{ - "projects": [ - {"name": "projects/alpha", "status": {"display_name": "Alpha"}}, - {"name": "projects/beta", "status": {"display_name": "Beta"}} - ] -} -''' - -[[Server]] -Pattern = "GET /api/2.0/postgres/projects/foo" -Response.Body = ''' -{ - "name": "projects/foo", - "status": {"display_name": "Foo Project"} -} -''' - -[[Server]] -Pattern = "GET /api/2.0/postgres/projects/foo/branches" -Response.Body = ''' -{ - "branches": [ - {"name": "projects/foo/branches/main"}, - {"name": "projects/foo/branches/dev"} - ] -} -''' - -[[Server]] -Pattern = "GET /api/2.0/postgres/projects/bar" -Response.Body = ''' -{ - "name": "projects/bar", - "status": {"display_name": "Bar Project"} -} -''' - -[[Server]] -Pattern = "GET /api/2.0/postgres/projects/bar/branches" -Response.Body = ''' -{ - "branches": [ - {"name": "projects/bar/branches/only"} - ] -} -''' - -[[Server]] -Pattern = "GET /api/2.0/postgres/projects/bar/branches/only/endpoints" -Response.Body = ''' -{ - "endpoints": [ - {"name": "projects/bar/branches/only/endpoints/read-write"}, - {"name": "projects/bar/branches/only/endpoints/read-only"} - ] -} -''' diff --git a/acceptance/cmd/experimental/postgres/query/argument-errors/out.test.toml b/acceptance/cmd/experimental/postgres/query/argument-errors/out.test.toml deleted file mode 100644 index 40bb0d10471..00000000000 --- a/acceptance/cmd/experimental/postgres/query/argument-errors/out.test.toml +++ /dev/null @@ -1,8 +0,0 @@ -Local = true -Cloud = false - -[GOOS] - windows = false - -[EnvMatrix] - DATABRICKS_BUNDLE_ENGINE = ["terraform", "direct"] diff --git a/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt b/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt deleted file mode 100644 index c071466a1e3..00000000000 --- 
a/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt +++ /dev/null @@ -1,48 +0,0 @@ - -=== No SQL argument should error: ->>> musterr [CLI] experimental postgres query --target projects/foo -Error: accepts 1 arg(s), received 0 - -=== Empty SQL should error: ->>> musterr [CLI] experimental postgres query --target projects/foo -Error: no SQL provided - -=== Neither targeting form should error: ->>> musterr [CLI] experimental postgres query SELECT 1 -Error: must specify --target or --project - -=== Both --target and --project should error: ->>> musterr [CLI] experimental postgres query --target projects/foo --project foo SELECT 1 -Error: if any flags in the group [target project] are set none of the others can be; [project target] were all set - -=== Both --target and --branch should error: ->>> musterr [CLI] experimental postgres query --target projects/foo --branch main SELECT 1 -Error: if any flags in the group [target branch] are set none of the others can be; [branch target] were all set - -=== Branch without project should error: ->>> musterr [CLI] experimental postgres query --branch main SELECT 1 -Error: --project is required when using --branch or --endpoint - -=== Endpoint without project should error: ->>> musterr [CLI] experimental postgres query --endpoint primary SELECT 1 -Error: --project is required when using --branch or --endpoint - -=== Endpoint without branch should error: ->>> musterr [CLI] experimental postgres query --project foo --endpoint primary SELECT 1 -Error: --branch is required when using --endpoint - -=== Max-retries 0 should error: ->>> musterr [CLI] experimental postgres query --project foo --branch main --max-retries 0 SELECT 1 -Error: --max-retries must be at least 1; got 0 - -=== Provisioned-shaped target should error pointing at psql: ->>> musterr [CLI] experimental postgres query --target my-instance SELECT 1 -Error: provisioned instances are not yet supported by this experimental command; use 'databricks psql ' for now - -=== Malformed autoscaling path should error: ->>> musterr [CLI] experimental postgres query --target projects/ SELECT 1 -Error: invalid resource path: missing project ID - -=== Trailing components after endpoint should error: ->>> musterr [CLI] experimental postgres query --target projects/foo/branches/bar/endpoints/baz/extra SELECT 1 -Error: invalid resource path: trailing components after endpoint: projects/foo/branches/bar/endpoints/baz/extra diff --git a/acceptance/cmd/experimental/postgres/query/argument-errors/script b/acceptance/cmd/experimental/postgres/query/argument-errors/script deleted file mode 100644 index 8d64bf307ed..00000000000 --- a/acceptance/cmd/experimental/postgres/query/argument-errors/script +++ /dev/null @@ -1,35 +0,0 @@ -title "No SQL argument should error:" -trace musterr $CLI experimental postgres query --target projects/foo - -title "Empty SQL should error:" -trace musterr $CLI experimental postgres query --target projects/foo " " - -title "Neither targeting form should error:" -trace musterr $CLI experimental postgres query "SELECT 1" - -title "Both --target and --project should error:" -trace musterr $CLI experimental postgres query --target projects/foo --project foo "SELECT 1" - -title "Both --target and --branch should error:" -trace musterr $CLI experimental postgres query --target projects/foo --branch main "SELECT 1" - -title "Branch without project should error:" -trace musterr $CLI experimental postgres query --branch main "SELECT 1" - -title "Endpoint without project should 
error:" -trace musterr $CLI experimental postgres query --endpoint primary "SELECT 1" - -title "Endpoint without branch should error:" -trace musterr $CLI experimental postgres query --project foo --endpoint primary "SELECT 1" - -title "Max-retries 0 should error:" -trace musterr $CLI experimental postgres query --project foo --branch main --max-retries 0 "SELECT 1" - -title "Provisioned-shaped target should error pointing at psql:" -trace musterr $CLI experimental postgres query --target my-instance "SELECT 1" - -title "Malformed autoscaling path should error:" -trace musterr $CLI experimental postgres query --target projects/ "SELECT 1" - -title "Trailing components after endpoint should error:" -trace musterr $CLI experimental postgres query --target projects/foo/branches/bar/endpoints/baz/extra "SELECT 1" diff --git a/acceptance/cmd/experimental/postgres/query/argument-errors/test.toml b/acceptance/cmd/experimental/postgres/query/argument-errors/test.toml deleted file mode 100644 index 3371f08de12..00000000000 --- a/acceptance/cmd/experimental/postgres/query/argument-errors/test.toml +++ /dev/null @@ -1,3 +0,0 @@ -# Argument validation runs before any SDK call. No mocked HTTP responses are -# needed; CLI either errors at flag-parse time or at our own validate function. -GOOS.windows = false diff --git a/acceptance/cmd/psql/argument-errors/output.txt b/acceptance/cmd/psql/argument-errors/output.txt index cbf6c093b21..35da5961dec 100644 --- a/acceptance/cmd/psql/argument-errors/output.txt +++ b/acceptance/cmd/psql/argument-errors/output.txt @@ -59,10 +59,6 @@ Error: invalid resource path: missing branch ID >>> musterr [CLI] psql projects/my-project/branches/main/endpoints/ Error: invalid resource path: missing endpoint ID -=== Trailing components after endpoint should error: ->>> musterr [CLI] psql projects/my-project/branches/main/endpoints/primary/extra -Error: invalid resource path: trailing components after endpoint: projects/my-project/branches/main/endpoints/primary/extra - === Provisioned flag with --project should error: >>> musterr [CLI] psql --provisioned --project foo Error: cannot use --project, --branch, or --endpoint flags with --provisioned diff --git a/acceptance/cmd/psql/argument-errors/script b/acceptance/cmd/psql/argument-errors/script index 7db1cdbd271..7806efb0744 100644 --- a/acceptance/cmd/psql/argument-errors/script +++ b/acceptance/cmd/psql/argument-errors/script @@ -38,9 +38,6 @@ trace musterr $CLI psql projects/my-project/branches/ title "Invalid path with missing endpoint ID should error:" trace musterr $CLI psql projects/my-project/branches/main/endpoints/ -title "Trailing components after endpoint should error:" -trace musterr $CLI psql projects/my-project/branches/main/endpoints/primary/extra - title "Provisioned flag with --project should error:" trace musterr $CLI psql --provisioned --project foo diff --git a/acceptance/cmd/psql/postgres/output.txt b/acceptance/cmd/psql/postgres/output.txt index 8df91c6321c..5269553a0ce 100644 --- a/acceptance/cmd/psql/postgres/output.txt +++ b/acceptance/cmd/psql/postgres/output.txt @@ -50,7 +50,7 @@ PGSSLMODE=require Project: Init Project Branch: main Endpoint: init-ep -Error: endpoint is not ready for accepting connections (state: INIT) +Error: endpoint is not ready for accepting connections === Branch flag without project should fail: >>> musterr [CLI] psql --branch some-branch diff --git a/cmd/psql/psql.go b/cmd/psql/psql.go index 9be7fb5c5df..e7f3a65f8b3 100644 --- a/cmd/psql/psql.go +++ b/cmd/psql/psql.go @@ -11,7 
+11,6 @@ import ( "github.com/databricks/cli/libs/cmdctx" "github.com/databricks/cli/libs/cmdgroup" "github.com/databricks/cli/libs/cmdio" - "github.com/databricks/cli/libs/lakebase/target" libpsql "github.com/databricks/cli/libs/psql" "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/service/database" @@ -87,9 +86,9 @@ For more information, see: https://docs.databricks.com/aws/en/oltp/ if argsLenAtDash < 0 { argsLenAtDash = len(args) } - targetArg := "" + target := "" if argsLenAtDash == 1 { - targetArg = args[0] + target = args[0] } else if argsLenAtDash > 1 { return errors.New("expected at most one positional argument for target") } @@ -110,17 +109,16 @@ For more information, see: https://docs.databricks.com/aws/en/oltp/ } // Positional argument takes precedence - if targetArg != "" { - if target.IsAutoscalingPath(targetArg) { + if target != "" { + if strings.HasPrefix(target, "projects/") { if provisionedFlag { return errors.New("cannot use --provisioned flag with an autoscaling resource path") } - spec, err := target.ParseAutoscalingPath(targetArg) + projectID, branchID, endpointID, err := parseResourcePath(target) if err != nil { return err } - projectID, branchID, endpointID := spec.ProjectID, spec.BranchID, spec.EndpointID // Check for conflicts between path and flags if projectFlag != "" && projectFlag != projectID { @@ -151,7 +149,7 @@ For more information, see: https://docs.databricks.com/aws/en/oltp/ if autoscalingFlag { return errors.New("cannot use --autoscaling flag with a provisioned instance name") } - return connectProvisioned(ctx, targetArg, retryConfig, extraArgs) + return connectProvisioned(ctx, target, retryConfig, extraArgs) } // No positional argument - use flags only @@ -199,6 +197,45 @@ For more information, see: https://docs.databricks.com/aws/en/oltp/ return cmd } +// parseResourcePath extracts project, branch, and endpoint IDs from a resource path. +// Returns an error for malformed paths. +func parseResourcePath(input string) (project, branch, endpoint string, err error) { + parts := strings.Split(input, "/") + + // Must start with projects/{project_id} + if len(parts) < 2 || parts[0] != "projects" { + return "", "", "", fmt.Errorf("invalid resource path: %s", input) + } + if parts[1] == "" { + return "", "", "", errors.New("invalid resource path: missing project ID") + } + project = parts[1] + + // Optional: branches/{branch_id} + if len(parts) > 2 { + if len(parts) < 4 || parts[2] != "branches" { + return "", "", "", errors.New("invalid resource path: expected 'branches' after project") + } + if parts[3] == "" { + return "", "", "", errors.New("invalid resource path: missing branch ID") + } + branch = parts[3] + } + + // Optional: endpoints/{endpoint_id} + if len(parts) > 4 { + if len(parts) < 6 || parts[4] != "endpoints" { + return "", "", "", errors.New("invalid resource path: expected 'endpoints' after branch") + } + if parts[5] == "" { + return "", "", "", errors.New("invalid resource path: missing endpoint ID") + } + endpoint = parts[5] + } + + return project, branch, endpoint, nil +} + // listAllDatabases fetches all database instances and projects in parallel. // Errors are silently ignored; callers should check for empty results. 
func listAllDatabases(ctx context.Context, w *databricks.WorkspaceClient) ([]database.DatabaseInstance, []postgres.Project) { @@ -211,12 +248,12 @@ func listAllDatabases(ctx context.Context, w *databricks.WorkspaceClient) ([]dat projectsCh := make(chan result[postgres.Project], 1) go func() { - instances, err := target.ListProvisionedInstances(ctx, w) + instances, err := w.Database.ListDatabaseInstancesAll(ctx, database.ListDatabaseInstancesRequest{}) instancesCh <- result[database.DatabaseInstance]{instances, err} }() go func() { - projects, err := target.ListProjects(ctx, w) + projects, err := w.Postgres.ListProjectsAll(ctx, postgres.ListProjectsRequest{}) projectsCh <- result[postgres.Project]{projects, err} }() @@ -257,7 +294,7 @@ func showSelectionAndConnect(ctx context.Context, retryConfig libpsql.RetryConfi }) } for _, proj := range projects { - displayName := target.ProjectIDFromName(proj.Name) + displayName := extractIDFromName(proj.Name, "projects") if proj.Status != nil && proj.Status.DisplayName != "" { displayName = proj.Status.DisplayName } @@ -278,7 +315,7 @@ func showSelectionAndConnect(ctx context.Context, retryConfig libpsql.RetryConfi } if after, ok := strings.CutPrefix(selected, "autoscaling:"); ok { projectName := after - projectID := target.ProjectIDFromName(projectName) + projectID := extractIDFromName(projectName, "projects") return connectAutoscaling(ctx, projectID, "", "", retryConfig, extraArgs) } diff --git a/cmd/psql/psql_autoscaling.go b/cmd/psql/psql_autoscaling.go index 04ccd4bef6b..00c555e4c12 100644 --- a/cmd/psql/psql_autoscaling.go +++ b/cmd/psql/psql_autoscaling.go @@ -4,10 +4,10 @@ import ( "context" "errors" "fmt" + "strings" "github.com/databricks/cli/libs/cmdctx" "github.com/databricks/cli/libs/cmdio" - "github.com/databricks/cli/libs/lakebase/target" libpsql "github.com/databricks/cli/libs/psql" "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/service/postgres" @@ -16,6 +16,18 @@ import ( // autoscalingDefaultDatabase is the default database for Lakebase Autoscaling projects. const autoscalingDefaultDatabase = "databricks_postgres" +// extractIDFromName extracts the ID component from a resource name. +// For example, extractIDFromName("projects/foo/branches/bar", "branches") returns "bar". +func extractIDFromName(name, component string) string { + parts := strings.Split(name, "/") + for i := range len(parts) - 1 { + if parts[i] == component { + return parts[i+1] + } + } + return name +} + // connectAutoscaling connects to a Lakebase Autoscaling endpoint. 
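For reference, extractIDFromName on the three input shapes it actually sees. A runnable sketch with the body copied from the hunk above (the example inputs in main are illustrative):

    package main

    import (
        "fmt"
        "strings"
    )

    // Body copied from extractIDFromName above so the sketch compiles
    // standalone: scan segment pairs and return the value that follows the
    // requested component, falling back to the input itself.
    func extractIDFromName(name, component string) string {
        parts := strings.Split(name, "/")
        for i := range len(parts) - 1 {
            if parts[i] == component {
                return parts[i+1]
            }
        }
        return name
    }

    func main() {
        fmt.Println(extractIDFromName("projects/foo/branches/bar", "branches")) // bar
        fmt.Println(extractIDFromName("projects/foo", "projects"))              // foo
        fmt.Println(extractIDFromName("standalone-instance", "projects"))       // unchanged fallback
    }

The fallback on the last line is what lets provisioned instance names (which have no path segments) pass through untouched.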
func connectAutoscaling(ctx context.Context, projectID, branchID, endpointID string, retryConfig libpsql.RetryConfig, extraArgs []string) error { w := cmdctx.WorkspaceClient(ctx) @@ -38,9 +50,11 @@ func connectAutoscaling(ctx context.Context, projectID, branchID, endpointID str return errors.New("endpoint host information is not available") } - token, err := target.AutoscalingCredential(ctx, w, endpoint.Name) + cred, err := w.Postgres.GenerateDatabaseCredential(ctx, postgres.GenerateDatabaseCredentialRequest{ + Endpoint: endpoint.Name, + }) if err != nil { - return err + return fmt.Errorf("failed to get database credentials: %w", err) } var endpointType string @@ -61,7 +75,7 @@ func connectAutoscaling(ctx context.Context, projectID, branchID, endpointID str case postgres.EndpointStatusStateIdle: suffix = " (idle, waking up)" default: - return fmt.Errorf("endpoint is not ready for accepting connections (state: %s)", state) + return errors.New("endpoint is not ready for accepting connections") } cmdio.LogString(ctx, fmt.Sprintf("Connecting to %s endpoint%s...", endpointType, suffix)) @@ -69,7 +83,7 @@ func connectAutoscaling(ctx context.Context, projectID, branchID, endpointID str return libpsql.Connect(ctx, libpsql.ConnectOptions{ Host: endpoint.Status.Hosts.Host, Username: user.UserName, - Password: token, + Password: cred.Token, DefaultDatabase: autoscalingDefaultDatabase, ExtraArgs: extraArgs, }, retryConfig) @@ -88,7 +102,7 @@ func resolveEndpoint(ctx context.Context, w *databricks.WorkspaceClient, project } // Get project to display its name - project, err := target.GetProject(ctx, w, projectID) + project, err := w.Postgres.GetProject(ctx, postgres.GetProjectRequest{Name: "projects/" + projectID}) if err != nil { return nil, fmt.Errorf("failed to get project: %w", err) } @@ -122,7 +136,7 @@ func resolveEndpoint(ctx context.Context, w *databricks.WorkspaceClient, project } // Get endpoint to validate and return it - endpoint, err := target.GetEndpoint(ctx, w, projectID, branchID, endpointID) + endpoint, err := w.Postgres.GetEndpoint(ctx, postgres.GetEndpointRequest{Name: branch.Name + "/endpoints/" + endpointID}) if err != nil { return nil, fmt.Errorf("failed to get endpoint: %w", err) } @@ -131,40 +145,38 @@ func resolveEndpoint(ctx context.Context, w *databricks.WorkspaceClient, project return endpoint, nil } -// selectAmbiguous prompts the user to pick one of the choices in an -// AmbiguousError. Caller is expected to have logged a header (e.g. via the -// spinner) before invoking. Used to keep psql's interactive UX while letting -// the shared lib do the actual list+filter work. -// -// Choice.DisplayName is empty when the producer has no friendlier label than -// the ID (e.g. branches and endpoints, where the ID is the human label). -// The promptui template renders an empty Name as a blank row, so we fall back -// to the ID before handing off to cmdio.SelectOrdered. -func selectAmbiguous(ctx context.Context, amb *target.AmbiguousError, prompt string) (string, error) { - items := make([]cmdio.Tuple, 0, len(amb.Choices)) - for _, c := range amb.Choices { - name := c.DisplayName - if name == "" { - name = c.ID - } - items = append(items, cmdio.Tuple{Name: name, Id: c.ID}) - } - return cmdio.SelectOrdered(ctx, items, prompt) -} - // selectProjectID auto-selects if there's only one project, otherwise prompts user to select. // Returns the project ID (not the full project object). 
func selectProjectID(ctx context.Context, w *databricks.WorkspaceClient) (string, error) { sp := cmdio.NewSpinner(ctx) sp.Update("Loading projects...") - id, err := target.AutoSelectProject(ctx, w) + projects, err := w.Postgres.ListProjectsAll(ctx, postgres.ListProjectsRequest{}) sp.Close() + if err != nil { + return "", err + } + + if len(projects) == 0 { + return "", errors.New("no Lakebase Autoscaling projects found in workspace") + } + + // Auto-select if there's only one project + if len(projects) == 1 { + return extractIDFromName(projects[0].Name, "projects"), nil + } - var amb *target.AmbiguousError - if !errors.As(err, &amb) { - return id, err + // Multiple projects, prompt user to select + var items []cmdio.Tuple + for _, project := range projects { + projectID := extractIDFromName(project.Name, "projects") + displayName := projectID + if project.Status != nil && project.Status.DisplayName != "" { + displayName = project.Status.DisplayName + } + items = append(items, cmdio.Tuple{Name: displayName, Id: projectID}) } - return selectAmbiguous(ctx, amb, "Select project") + + return cmdio.SelectOrdered(ctx, items, "Select project") } // selectBranchID auto-selects if there's only one branch, otherwise prompts user to select. @@ -172,14 +184,31 @@ func selectProjectID(ctx context.Context, w *databricks.WorkspaceClient) (string func selectBranchID(ctx context.Context, w *databricks.WorkspaceClient, projectName string) (string, error) { sp := cmdio.NewSpinner(ctx) sp.Update("Loading branches...") - id, err := target.AutoSelectBranch(ctx, w, projectName) + branches, err := w.Postgres.ListBranchesAll(ctx, postgres.ListBranchesRequest{ + Parent: projectName, + }) sp.Close() + if err != nil { + return "", err + } + + if len(branches) == 0 { + return "", errors.New("no branches found in project") + } + + // Auto-select if there's only one branch + if len(branches) == 1 { + return extractIDFromName(branches[0].Name, "branches"), nil + } - var amb *target.AmbiguousError - if !errors.As(err, &amb) { - return id, err + // Multiple branches, prompt user to select + var items []cmdio.Tuple + for _, branch := range branches { + branchID := extractIDFromName(branch.Name, "branches") + items = append(items, cmdio.Tuple{Name: branchID, Id: branchID}) } - return selectAmbiguous(ctx, amb, "Select branch") + + return cmdio.SelectOrdered(ctx, items, "Select branch") } // selectEndpointID auto-selects if there's only one endpoint, otherwise prompts user to select. 
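All three selectors restored above reduce to the same zero/one/many skeleton. A generic sketch of that shape (autoSelect and pickFirst are illustrative names, not code from this series):

    package main

    import (
        "errors"
        "fmt"
    )

    // autoSelect captures the shared shape of selectProjectID,
    // selectBranchID, and selectEndpointID: error on zero candidates,
    // auto-pick a single candidate, and only prompt when there is a real
    // choice to make.
    func autoSelect(ids []string, prompt func([]string) (string, error)) (string, error) {
        switch len(ids) {
        case 0:
            return "", errors.New("no candidates found")
        case 1:
            return ids[0], nil
        default:
            return prompt(ids)
        }
    }

    func main() {
        pickFirst := func(ids []string) (string, error) { return ids[0], nil }

        id, _ := autoSelect([]string{"only"}, pickFirst)
        fmt.Println(id) // only (no prompt needed)

        id, _ = autoSelect([]string{"main", "dev"}, pickFirst)
        fmt.Println(id) // main (whatever the prompt returns)
    }

The real selectors differ only in how they build the candidate list and in the cmdio.Tuple display names they hand to the prompt.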
@@ -187,12 +216,29 @@ func selectBranchID(ctx context.Context, w *databricks.WorkspaceClient, projectN func selectEndpointID(ctx context.Context, w *databricks.WorkspaceClient, branchName string) (string, error) { sp := cmdio.NewSpinner(ctx) sp.Update("Loading endpoints...") - id, err := target.AutoSelectEndpoint(ctx, w, branchName) + endpoints, err := w.Postgres.ListEndpointsAll(ctx, postgres.ListEndpointsRequest{ + Parent: branchName, + }) sp.Close() + if err != nil { + return "", err + } + + if len(endpoints) == 0 { + return "", errors.New("no endpoints found in branch") + } - var amb *target.AmbiguousError - if !errors.As(err, &amb) { - return id, err + // Auto-select if there's only one endpoint + if len(endpoints) == 1 { + return extractIDFromName(endpoints[0].Name, "endpoints"), nil } - return selectAmbiguous(ctx, amb, "Select endpoint") + + // Multiple endpoints, prompt user to select + var items []cmdio.Tuple + for _, endpoint := range endpoints { + endpointID := extractIDFromName(endpoint.Name, "endpoints") + items = append(items, cmdio.Tuple{Name: endpointID, Id: endpointID}) + } + + return cmdio.SelectOrdered(ctx, items, "Select endpoint") } diff --git a/cmd/psql/psql_provisioned.go b/cmd/psql/psql_provisioned.go index c7208906aa8..88ca1bb9181 100644 --- a/cmd/psql/psql_provisioned.go +++ b/cmd/psql/psql_provisioned.go @@ -7,10 +7,10 @@ import ( "github.com/databricks/cli/libs/cmdctx" "github.com/databricks/cli/libs/cmdio" - "github.com/databricks/cli/libs/lakebase/target" libpsql "github.com/databricks/cli/libs/psql" "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/service/database" + "github.com/google/uuid" ) // provisionedDefaultDatabase is the default database for Lakebase Provisioned instances. @@ -39,9 +39,12 @@ func connectProvisioned(ctx context.Context, instanceName string, retryConfig li return errors.New("database instance is not ready for accepting connections") } - token, err := target.ProvisionedCredential(ctx, w, instance.Name) + cred, err := w.Database.GenerateDatabaseCredential(ctx, database.GenerateDatabaseCredentialRequest{ + InstanceNames: []string{instance.Name}, + RequestId: uuid.NewString(), + }) if err != nil { - return err + return fmt.Errorf("failed to get database credentials: %w", err) } cmdio.LogString(ctx, "Connecting to database instance...") @@ -49,7 +52,7 @@ func connectProvisioned(ctx context.Context, instanceName string, retryConfig li return libpsql.Connect(ctx, libpsql.ConnectOptions{ Host: instance.ReadWriteDns, Username: user.UserName, - Password: token, + Password: cred.Token, DefaultDatabase: provisionedDefaultDatabase, ExtraArgs: extraArgs, }, retryConfig) @@ -67,11 +70,15 @@ func resolveInstance(ctx context.Context, w *databricks.WorkspaceClient, instanc } } - // target.GetProvisioned patches Name on the response; the SDK's - // GetDatabaseInstance does not always populate it. 
- instance, err := target.GetProvisioned(ctx, w, instanceName) + instance, err := w.Database.GetDatabaseInstance(ctx, database.GetDatabaseInstanceRequest{ + Name: instanceName, + }) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to get database instance: %w", err) + } + // Ensure Name is set (API response may not include it) + if instance.Name == "" { + instance.Name = instanceName } cmdio.LogString(ctx, "Instance: "+instance.Name) @@ -83,12 +90,26 @@ func resolveInstance(ctx context.Context, w *databricks.WorkspaceClient, instanc func selectInstanceID(ctx context.Context, w *databricks.WorkspaceClient) (string, error) { sp := cmdio.NewSpinner(ctx) sp.Update("Loading instances...") - id, err := target.AutoSelectProvisioned(ctx, w) + instances, err := w.Database.ListDatabaseInstancesAll(ctx, database.ListDatabaseInstancesRequest{}) sp.Close() + if err != nil { + return "", err + } - var amb *target.AmbiguousError - if !errors.As(err, &amb) { - return id, err + if len(instances) == 0 { + return "", errors.New("no Lakebase Provisioned instances found in workspace") } - return selectAmbiguous(ctx, amb, "Select instance") + + // Auto-select if there's only one instance + if len(instances) == 1 { + return instances[0].Name, nil + } + + // Multiple instances, prompt user to select + var items []cmdio.Tuple + for _, inst := range instances { + items = append(items, cmdio.Tuple{Name: inst.Name, Id: inst.Name}) + } + + return cmdio.SelectOrdered(ctx, items, "Select instance") } diff --git a/cmd/psql/psql_test.go b/cmd/psql/psql_test.go new file mode 100644 index 00000000000..fc8a7e53cba --- /dev/null +++ b/cmd/psql/psql_test.go @@ -0,0 +1,83 @@ +package psql + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseResourcePath(t *testing.T) { + tests := []struct { + name string + input string + project string + branch string + endpoint string + wantErr string + }{ + { + name: "project only", + input: "projects/my-project", + project: "my-project", + }, + { + name: "project and branch", + input: "projects/my-project/branches/main", + project: "my-project", + branch: "main", + }, + { + name: "full path", + input: "projects/my-project/branches/main/endpoints/primary", + project: "my-project", + branch: "main", + endpoint: "primary", + }, + { + name: "missing project ID", + input: "projects/", + wantErr: "missing project ID", + }, + { + name: "missing branch ID", + input: "projects/my-project/branches/", + wantErr: "missing branch ID", + }, + { + name: "missing endpoint ID", + input: "projects/my-project/branches/main/endpoints/", + wantErr: "missing endpoint ID", + }, + { + name: "invalid segment after project", + input: "projects/my-project/invalid/foo", + wantErr: "expected 'branches' after project", + }, + { + name: "invalid segment after branch", + input: "projects/my-project/branches/main/invalid/foo", + wantErr: "expected 'endpoints' after branch", + }, + { + name: "not a projects path", + input: "something/else", + wantErr: "invalid resource path", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + project, branch, endpoint, err := parseResourcePath(tc.input) + if tc.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.wantErr) + return + } + require.NoError(t, err) + assert.Equal(t, tc.project, project) + assert.Equal(t, tc.branch, branch) + assert.Equal(t, tc.endpoint, endpoint) + }) + } +} diff --git a/libs/lakebase/target/autoscaling.go 
b/experimental/postgres/cmd/internal/target/autoscaling.go
similarity index 100%
rename from libs/lakebase/target/autoscaling.go
rename to experimental/postgres/cmd/internal/target/autoscaling.go
diff --git a/libs/lakebase/target/provisioned.go b/experimental/postgres/cmd/internal/target/provisioned.go
similarity index 100%
rename from libs/lakebase/target/provisioned.go
rename to experimental/postgres/cmd/internal/target/provisioned.go
diff --git a/libs/lakebase/target/target.go b/experimental/postgres/cmd/internal/target/target.go
similarity index 100%
rename from libs/lakebase/target/target.go
rename to experimental/postgres/cmd/internal/target/target.go
diff --git a/libs/lakebase/target/target_test.go b/experimental/postgres/cmd/internal/target/target_test.go
similarity index 100%
rename from libs/lakebase/target/target_test.go
rename to experimental/postgres/cmd/internal/target/target_test.go
diff --git a/experimental/postgres/cmd/targeting.go b/experimental/postgres/cmd/targeting.go
index 5e72840f952..7f6a6830daa 100644
--- a/experimental/postgres/cmd/targeting.go
+++ b/experimental/postgres/cmd/targeting.go
@@ -5,8 +5,8 @@ import (
 	"errors"
 	"fmt"
 
+	"github.com/databricks/cli/experimental/postgres/cmd/internal/target"
 	"github.com/databricks/cli/libs/cmdctx"
-	"github.com/databricks/cli/libs/lakebase/target"
 	"github.com/databricks/databricks-sdk-go"
 	"github.com/databricks/databricks-sdk-go/service/postgres"
 )

From 1a0798825ce9be527e1f7f796dc640bcd9e797a4 Mon Sep 17 00:00:00 2001
From: simon
Date: Thu, 30 Apr 2026 13:11:53 +0200
Subject: [PATCH 16/25] Extract output-mode handling into experimental/libs/sqlcli

Both aitools query and postgres query had near-identical output-mode
selection: same DATABRICKS_OUTPUT_FORMAT env var, same flag-vs-env
precedence, same staticTableThreshold=30, same Format type with
text/json/csv values.

Promote the shared bits to experimental/libs/sqlcli:

- sqlcli.EnvOutputFormat, sqlcli.StaticTableThreshold consts
- sqlcli.Format typedef + sqlcli.OutputText/JSON/CSV consts
- sqlcli.AllFormats slice (canonical order for completions)
- sqlcli.ResolveFormat: handles flag > env > default precedence with
  the explicit-text-on-pipe-is-honoured rule

Both consumers now import sqlcli. The package lives under
experimental/libs/ rather than libs/ so it inherits the experimental
stability guarantee of its consumers; when both commands graduate, the
package can be promoted alongside.

The aitools migration is a pure refactor (no behavior change). The
postgres command's output.go and output_test.go are deleted; tests
moved to experimental/libs/sqlcli.
Co-authored-by: Isaac
---
 experimental/aitools/cmd/query.go        |  65 ++++++---------
 experimental/aitools/cmd/query_test.go   |  22 ++---
 experimental/libs/sqlcli/output.go       |  93 +++++++++++++++++++++
 experimental/libs/sqlcli/output_test.go  | 100 +++++++++++++++++++++++
 experimental/postgres/cmd/output.go      |  79 ------------------
 experimental/postgres/cmd/output_test.go |  93 ---------------------
 experimental/postgres/cmd/query.go       |  15 ++--
 7 files changed, 239 insertions(+), 228 deletions(-)
 create mode 100644 experimental/libs/sqlcli/output.go
 create mode 100644 experimental/libs/sqlcli/output_test.go
 delete mode 100644 experimental/postgres/cmd/output.go
 delete mode 100644 experimental/postgres/cmd/output_test.go

diff --git a/experimental/aitools/cmd/query.go b/experimental/aitools/cmd/query.go
index 7e9ae1d030d..45c5669c699 100644
--- a/experimental/aitools/cmd/query.go
+++ b/experimental/aitools/cmd/query.go
@@ -14,10 +14,9 @@ import (
 	"github.com/databricks/cli/cmd/root"
 	"github.com/databricks/cli/experimental/aitools/lib/middlewares"
 	"github.com/databricks/cli/experimental/aitools/lib/session"
+	"github.com/databricks/cli/experimental/libs/sqlcli"
 	"github.com/databricks/cli/libs/cmdctx"
 	"github.com/databricks/cli/libs/cmdio"
-	"github.com/databricks/cli/libs/env"
-	"github.com/databricks/cli/libs/flags"
 	"github.com/databricks/cli/libs/log"
 	"github.com/databricks/databricks-sdk-go/service/sql"
 	"github.com/spf13/cobra"
@@ -35,16 +34,6 @@ const (

 	// cancelTimeout is how long to wait for server-side cancellation.
 	cancelTimeout = 10 * time.Second
-
-	// staticTableThreshold is the maximum number of rows rendered as a static table.
-	// Beyond this, an interactive scrollable table is used.
-	staticTableThreshold = 30
-
-	// outputCSV is the csv output format, supported only by the query command.
-	outputCSV = "csv"
-
-	// envOutputFormat matches the env var name in cmd/root/io.go.
-	envOutputFormat = "DATABRICKS_OUTPUT_FORMAT"
 )

 type queryOutputMode int
@@ -55,8 +44,13 @@ const (
 	queryOutputModeInteractiveTable
 )

-func selectQueryOutputMode(outputType flags.Output, stdoutInteractive, promptSupported bool, rowCount int) queryOutputMode {
-	if outputType == flags.OutputJSON {
+// selectQueryOutputMode picks the rendering mode for a single-query result.
+// JSON is the only machine-readable option; static and interactive are
+// table variants chosen by row count and TTY capabilities. Only the
+// threshold is shared with sqlcli; the three-way decision is aitools-specific
+// because the postgres command's renderers have a different shape.
+func selectQueryOutputMode(format sqlcli.Format, stdoutInteractive, promptSupported bool, rowCount int) queryOutputMode {
+	if format == sqlcli.OutputJSON {
 		return queryOutputModeJSON
 	}
 	if !stdoutInteractive {
@@ -67,7 +61,7 @@ func selectQueryOutputMode(outputType flags.Output, stdoutInteractive, promptSup
 	if !promptSupported {
 		return queryOutputModeStaticTable
 	}
-	if rowCount <= staticTableThreshold {
+	if rowCount <= sqlcli.StaticTableThreshold {
 		return queryOutputModeStaticTable
 	}
 	return queryOutputModeInteractiveTable
@@ -119,24 +113,15 @@ interactive table browser. Use --output csv to export results as CSV.`,
 		RunE: func(cmd *cobra.Command, args []string) error {
 			ctx := cmd.Context()

-			// Normalize case to match root --output behavior (flags.Output.Set lowercases).
-			outputFormat = strings.ToLower(outputFormat)
-
-			// If --output wasn't explicitly passed, check the env var.
-			// Invalid env values are silently ignored, matching cmd/root/io.go.
-			if !cmd.Flag("output").Changed {
-				if v, ok := env.Lookup(ctx, envOutputFormat); ok {
-					switch flags.Output(strings.ToLower(v)) {
-					case flags.OutputText, flags.OutputJSON, outputCSV:
-						outputFormat = strings.ToLower(v)
-					}
-				}
-			}
-
-			switch flags.Output(outputFormat) {
-			case flags.OutputText, flags.OutputJSON, outputCSV:
-			default:
-				return fmt.Errorf("unsupported output format %q, accepted values: text, json, csv", outputFormat)
+			// Resolve the effective format via sqlcli so the env-var
+			// precedence and explicit-text-on-pipe handling stay in sync
+			// across commands. We pass stdoutTTY=true to keep the original
+			// aitools behavior of not auto-falling back to JSON here; the
+			// per-result render mode further down already handles the pipe
+			// case via selectQueryOutputMode.
+			format, err := sqlcli.ResolveFormat(ctx, outputFormat, cmd.Flag("output").Changed, true)
+			if err != nil {
+				return err
 			}

 			sqls, err := resolveSQLs(ctx, cmd, args, filePaths)
@@ -146,7 +131,7 @@ interactive table browser. Use --output csv to export results as CSV.`,

 			// Reject incompatible flag combinations before any API call so the
 			// user sees the real error instead of an auth/warehouse failure.
-			if len(sqls) > 1 && flags.Output(outputFormat) != flags.OutputJSON {
+			if len(sqls) > 1 && format != sqlcli.OutputJSON {
 				return fmt.Errorf("multiple queries require --output json (got %q); pass --output json to receive a JSON array of per-statement results", outputFormat)
 			}

@@ -173,7 +158,7 @@ interactive table browser. Use --output csv to export results as CSV.`,
 			}

 			// CSV bypasses the normal output mode selection.
-			if flags.Output(outputFormat) == outputCSV {
+			if format == sqlcli.OutputCSV {
 				if len(columns) == 0 && len(rows) == 0 {
 					return nil
 				}
@@ -190,7 +175,7 @@ interactive table browser. Use --output csv to export results as CSV.`,
 			stdoutInteractive := cmdio.SupportsColor(ctx, cmd.OutOrStdout())
 			promptSupported := cmdio.IsPromptSupported(ctx)

-			switch selectQueryOutputMode(flags.Output(outputFormat), stdoutInteractive, promptSupported, len(rows)) {
+			switch selectQueryOutputMode(format, stdoutInteractive, promptSupported, len(rows)) {
 			case queryOutputModeJSON:
 				return renderJSON(cmd.OutOrStdout(), columns, rows)
 			case queryOutputModeStaticTable:
@@ -206,9 +191,13 @@ interactive table browser. Use --output csv to export results as CSV.`,
 	cmd.Flags().IntVar(&concurrency, "concurrency", defaultBatchConcurrency, "Maximum in-flight statements when running a batch of queries")
 	// Local --output flag shadows the root command's persistent --output flag,
 	// adding csv support for this command only.
- cmd.Flags().StringVarP(&outputFormat, "output", "o", string(flags.OutputText), "Output format: text, json, or csv") + cmd.Flags().StringVarP(&outputFormat, "output", "o", string(sqlcli.OutputText), "Output format: text, json, or csv") cmd.RegisterFlagCompletionFunc("output", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) { - return []string{string(flags.OutputText), string(flags.OutputJSON), string(outputCSV)}, cobra.ShellCompDirectiveNoFileComp + out := make([]string, len(sqlcli.AllFormats)) + for i, f := range sqlcli.AllFormats { + out[i] = string(f) + } + return out, cobra.ShellCompDirectiveNoFileComp }) return cmd diff --git a/experimental/aitools/cmd/query_test.go b/experimental/aitools/cmd/query_test.go index 59de11d578a..c85edc64722 100644 --- a/experimental/aitools/cmd/query_test.go +++ b/experimental/aitools/cmd/query_test.go @@ -10,9 +10,9 @@ import ( "time" "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/experimental/libs/sqlcli" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/env" - "github.com/databricks/cli/libs/flags" mocksql "github.com/databricks/databricks-sdk-go/experimental/mocks/service/sql" "github.com/databricks/databricks-sdk-go/service/sql" "github.com/spf13/cobra" @@ -271,7 +271,7 @@ func TestResolveWarehouseIDWithFlag(t *testing.T) { func TestSelectQueryOutputMode(t *testing.T) { tests := []struct { name string - outputType flags.Output + format sqlcli.Format stdoutInteractive bool promptSupported bool rowCount int @@ -279,7 +279,7 @@ func TestSelectQueryOutputMode(t *testing.T) { }{ { name: "json flag always returns json", - outputType: flags.OutputJSON, + format: sqlcli.OutputJSON, stdoutInteractive: true, promptSupported: true, rowCount: 999, @@ -287,7 +287,7 @@ func TestSelectQueryOutputMode(t *testing.T) { }, { name: "non interactive stdout returns json", - outputType: flags.OutputText, + format: sqlcli.OutputText, stdoutInteractive: false, promptSupported: true, rowCount: 5, @@ -295,33 +295,33 @@ func TestSelectQueryOutputMode(t *testing.T) { }, { name: "missing stdin interactivity falls back to static table", - outputType: flags.OutputText, + format: sqlcli.OutputText, stdoutInteractive: true, promptSupported: false, - rowCount: staticTableThreshold + 10, + rowCount: sqlcli.StaticTableThreshold + 10, want: queryOutputModeStaticTable, }, { name: "small results use static table", - outputType: flags.OutputText, + format: sqlcli.OutputText, stdoutInteractive: true, promptSupported: true, - rowCount: staticTableThreshold, + rowCount: sqlcli.StaticTableThreshold, want: queryOutputModeStaticTable, }, { name: "large results use interactive table", - outputType: flags.OutputText, + format: sqlcli.OutputText, stdoutInteractive: true, promptSupported: true, - rowCount: staticTableThreshold + 1, + rowCount: sqlcli.StaticTableThreshold + 1, want: queryOutputModeInteractiveTable, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - got := selectQueryOutputMode(tc.outputType, tc.stdoutInteractive, tc.promptSupported, tc.rowCount) + got := selectQueryOutputMode(tc.format, tc.stdoutInteractive, tc.promptSupported, tc.rowCount) assert.Equal(t, tc.want, got) }) } diff --git a/experimental/libs/sqlcli/output.go b/experimental/libs/sqlcli/output.go new file mode 100644 index 00000000000..4643303cd23 --- /dev/null +++ b/experimental/libs/sqlcli/output.go @@ -0,0 +1,93 @@ +// Package sqlcli holds patterns shared by experimental SQL-running commands +// (currently `experimental aitools tools 
query` and `experimental postgres +// query`). The package lives under experimental/libs/ rather than libs/ so +// the commands depending on it inherit experimental-stability guarantees: +// when both consumers graduate, this package can be promoted alongside +// (or its API stabilised first). +package sqlcli + +import ( + "context" + "fmt" + "slices" + "strings" + + "github.com/databricks/cli/libs/env" +) + +// EnvOutputFormat matches the env var name in cmd/root/io.go. +// Reading it lets pipelines set DATABRICKS_OUTPUT_FORMAT once for all +// commands. +const EnvOutputFormat = "DATABRICKS_OUTPUT_FORMAT" + +// StaticTableThreshold is the row count above which interactive callers may +// hand off to libs/tableview's scrollable viewer. Smaller results stay in a +// static tabwriter table so they pipe to scripts unchanged. +const StaticTableThreshold = 30 + +// Format is the user-selectable output shape. Using a string typedef instead +// of an int enum keeps the help text and DATABRICKS_OUTPUT_FORMAT env var +// values self-describing. +type Format string + +const ( + OutputText Format = "text" + OutputJSON Format = "json" + OutputCSV Format = "csv" +) + +// AllFormats is the canonical order shown in completions / help. Sharing +// the slice avoids drift between consumers when a new format is added. +var AllFormats = []Format{OutputText, OutputJSON, OutputCSV} + +// ResolveFormat picks the effective output format. Precedence: +// +// 1. The local --output flag if it was explicitly set. +// 2. DATABRICKS_OUTPUT_FORMAT env var if set to a known value (invalid +// values are silently ignored, matching cmd/root/io.go and aitools). +// 3. The flag default (whatever the caller passes as flagValue). +// +// Then the auto-selection rule applies: a *defaulted* text mode on a non-TTY +// stdout falls back to JSON, so scripts piping the output get machine- +// readable output by default. An *explicit* --output text (flag or env) is +// honoured even on a pipe; per AGENTS.md we don't silently override flags +// the user set. +// +// flagSet is true if the user explicitly passed --output on the CLI. +// stdoutTTY is true if stdout is a terminal. +func ResolveFormat(ctx context.Context, flagValue string, flagSet, stdoutTTY bool) (Format, error) { + chosen := Format(strings.ToLower(flagValue)) + chosenExplicit := flagSet + + if !flagSet { + if v, ok := env.Lookup(ctx, EnvOutputFormat); ok { + candidate := Format(strings.ToLower(v)) + if IsKnown(candidate) { + chosen = candidate + chosenExplicit = true + } + } + } + + if !IsKnown(chosen) { + return "", fmt.Errorf("unsupported output format %q; expected one of: %s", flagValue, joinFormats(AllFormats)) + } + + if chosen == OutputText && !stdoutTTY && !chosenExplicit { + return OutputJSON, nil + } + return chosen, nil +} + +// IsKnown reports whether f is one of the formats in AllFormats. 
+func IsKnown(f Format) bool { + return slices.Contains(AllFormats, f) +} + +func joinFormats(formats []Format) string { + parts := make([]string, len(formats)) + for i, f := range formats { + parts[i] = string(f) + } + return strings.Join(parts, ", ") +} diff --git a/experimental/libs/sqlcli/output_test.go b/experimental/libs/sqlcli/output_test.go new file mode 100644 index 00000000000..1e91bd9cf3d --- /dev/null +++ b/experimental/libs/sqlcli/output_test.go @@ -0,0 +1,100 @@ +package sqlcli + +import ( + "testing" + + "github.com/databricks/cli/libs/env" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestResolveFormat_Defaults(t *testing.T) { + ctx := t.Context() + got, err := ResolveFormat(ctx, "text", false, true) + require.NoError(t, err) + assert.Equal(t, OutputText, got) +} + +func TestResolveFormat_TextOnPipeFallsBackToJSON(t *testing.T) { + ctx := t.Context() + got, err := ResolveFormat(ctx, "text", false, false) + require.NoError(t, err) + assert.Equal(t, OutputJSON, got) +} + +func TestResolveFormat_ExplicitTextOnPipeIsHonoured(t *testing.T) { + ctx := t.Context() + got, err := ResolveFormat(ctx, "text", true, false) + require.NoError(t, err) + assert.Equal(t, OutputText, got) +} + +func TestResolveFormat_EnvVarTextOnPipeIsHonoured(t *testing.T) { + ctx := env.Set(t.Context(), EnvOutputFormat, "text") + got, err := ResolveFormat(ctx, "text", false, false) + require.NoError(t, err) + assert.Equal(t, OutputText, got) +} + +func TestResolveFormat_EnvVarCSVOnPipe(t *testing.T) { + ctx := env.Set(t.Context(), EnvOutputFormat, "csv") + got, err := ResolveFormat(ctx, "text", false, false) + require.NoError(t, err) + assert.Equal(t, OutputCSV, got) +} + +func TestResolveFormat_ExplicitJSON(t *testing.T) { + ctx := t.Context() + got, err := ResolveFormat(ctx, "json", true, true) + require.NoError(t, err) + assert.Equal(t, OutputJSON, got) +} + +func TestResolveFormat_ExplicitCSV(t *testing.T) { + ctx := t.Context() + got, err := ResolveFormat(ctx, "csv", true, true) + require.NoError(t, err) + assert.Equal(t, OutputCSV, got) +} + +func TestResolveFormat_EnvVarHonoredWhenFlagNotSet(t *testing.T) { + ctx := env.Set(t.Context(), EnvOutputFormat, "csv") + got, err := ResolveFormat(ctx, "text", false, true) + require.NoError(t, err) + assert.Equal(t, OutputCSV, got) +} + +func TestResolveFormat_FlagOverridesEnvVar(t *testing.T) { + ctx := env.Set(t.Context(), EnvOutputFormat, "csv") + got, err := ResolveFormat(ctx, "json", true, true) + require.NoError(t, err) + assert.Equal(t, OutputJSON, got) +} + +func TestResolveFormat_InvalidEnvVarIgnored(t *testing.T) { + ctx := env.Set(t.Context(), EnvOutputFormat, "yaml") + got, err := ResolveFormat(ctx, "text", false, true) + require.NoError(t, err) + assert.Equal(t, OutputText, got) +} + +func TestResolveFormat_InvalidFlagErrors(t *testing.T) { + ctx := t.Context() + _, err := ResolveFormat(ctx, "yaml", true, true) + assert.ErrorContains(t, err, "unsupported output format") +} + +func TestResolveFormat_CaseInsensitive(t *testing.T) { + ctx := t.Context() + got, err := ResolveFormat(ctx, "JSON", true, true) + require.NoError(t, err) + assert.Equal(t, OutputJSON, got) +} + +func TestIsKnown(t *testing.T) { + assert.True(t, IsKnown(OutputText)) + assert.True(t, IsKnown(OutputJSON)) + assert.True(t, IsKnown(OutputCSV)) + assert.False(t, IsKnown(Format("yaml"))) + assert.False(t, IsKnown(Format(""))) +} diff --git a/experimental/postgres/cmd/output.go b/experimental/postgres/cmd/output.go deleted file mode 100644 
index e5b59fec96f..00000000000 --- a/experimental/postgres/cmd/output.go +++ /dev/null @@ -1,79 +0,0 @@ -package postgrescmd - -import ( - "context" - "fmt" - "slices" - "strings" - - "github.com/databricks/cli/libs/env" -) - -// outputFormat is the user-selectable output shape. Using a string typedef -// instead of an int enum keeps the help text and DATABRICKS_OUTPUT_FORMAT env -// var values self-describing. -type outputFormat string - -const ( - outputText outputFormat = "text" - outputJSON outputFormat = "json" - outputCSV outputFormat = "csv" - - // envOutputFormat matches the env var name in cmd/root/io.go. Reading it - // here lets pipelines set DATABRICKS_OUTPUT_FORMAT once for all - // commands. See aitools query for a parallel pattern. - envOutputFormat = "DATABRICKS_OUTPUT_FORMAT" -) - -// allOutputFormats is the canonical order shown in completions / help. -var allOutputFormats = []outputFormat{outputText, outputJSON, outputCSV} - -// resolveOutputFormat picks the effective output format. Precedence: -// -// 1. The local --output flag if it was explicitly set. -// 2. DATABRICKS_OUTPUT_FORMAT env var if set to a known value (invalid -// values are silently ignored, matching cmd/root/io.go and aitools). -// 3. The flag default ("text"). -// -// Then the auto-selection rule applies: a *defaulted* text mode on a non-TTY -// stdout falls back to JSON, so scripts piping the output get machine- -// readable output by default. An *explicit* --output text is honoured even -// on a pipe; per CLAUDE.md we don't silently override flags the user set. -// -// flagSet is true if the user explicitly passed --output. stdoutTTY is true -// if stdout is a terminal. -func resolveOutputFormat(ctx context.Context, flagValue string, flagSet, stdoutTTY bool) (outputFormat, error) { - chosen := outputFormat(strings.ToLower(flagValue)) - chosenExplicit := flagSet - - if !flagSet { - if v, ok := env.Lookup(ctx, envOutputFormat); ok { - candidate := outputFormat(strings.ToLower(v)) - if isKnownOutputFormat(candidate) { - chosen = candidate - chosenExplicit = true - } - } - } - - if !isKnownOutputFormat(chosen) { - return "", fmt.Errorf("unsupported output format %q; expected one of: %s", flagValue, joinOutputFormats(allOutputFormats)) - } - - if chosen == outputText && !stdoutTTY && !chosenExplicit { - return outputJSON, nil - } - return chosen, nil -} - -func joinOutputFormats(formats []outputFormat) string { - parts := make([]string, len(formats)) - for i, f := range formats { - parts[i] = string(f) - } - return strings.Join(parts, ", ") -} - -func isKnownOutputFormat(f outputFormat) bool { - return slices.Contains(allOutputFormats, f) -} diff --git a/experimental/postgres/cmd/output_test.go b/experimental/postgres/cmd/output_test.go deleted file mode 100644 index 4598085805a..00000000000 --- a/experimental/postgres/cmd/output_test.go +++ /dev/null @@ -1,93 +0,0 @@ -package postgrescmd - -import ( - "testing" - - "github.com/databricks/cli/libs/env" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestResolveOutputFormat_Defaults(t *testing.T) { - ctx := t.Context() - - got, err := resolveOutputFormat(ctx, "text", false, true) - require.NoError(t, err) - assert.Equal(t, outputText, got) -} - -func TestResolveOutputFormat_TextOnPipeFallsBackToJSON(t *testing.T) { - ctx := t.Context() - got, err := resolveOutputFormat(ctx, "text", false, false) - require.NoError(t, err) - assert.Equal(t, outputJSON, got) -} - -func 
TestResolveOutputFormat_ExplicitTextOnPipeIsHonoured(t *testing.T) { - ctx := t.Context() - got, err := resolveOutputFormat(ctx, "text", true, false) - require.NoError(t, err) - assert.Equal(t, outputText, got) -} - -func TestResolveOutputFormat_EnvVarTextOnPipeIsHonoured(t *testing.T) { - ctx := env.Set(t.Context(), envOutputFormat, "text") - got, err := resolveOutputFormat(ctx, "text", false, false) - require.NoError(t, err) - assert.Equal(t, outputText, got) -} - -func TestResolveOutputFormat_EnvVarCSVOnPipe(t *testing.T) { - ctx := env.Set(t.Context(), envOutputFormat, "csv") - got, err := resolveOutputFormat(ctx, "text", false, false) - require.NoError(t, err) - assert.Equal(t, outputCSV, got) -} - -func TestResolveOutputFormat_ExplicitJSON(t *testing.T) { - ctx := t.Context() - got, err := resolveOutputFormat(ctx, "json", true, true) - require.NoError(t, err) - assert.Equal(t, outputJSON, got) -} - -func TestResolveOutputFormat_ExplicitCSV(t *testing.T) { - ctx := t.Context() - got, err := resolveOutputFormat(ctx, "csv", true, true) - require.NoError(t, err) - assert.Equal(t, outputCSV, got) -} - -func TestResolveOutputFormat_EnvVarHonoredWhenFlagNotSet(t *testing.T) { - ctx := env.Set(t.Context(), envOutputFormat, "csv") - got, err := resolveOutputFormat(ctx, "text", false, true) - require.NoError(t, err) - assert.Equal(t, outputCSV, got) -} - -func TestResolveOutputFormat_FlagOverridesEnvVar(t *testing.T) { - ctx := env.Set(t.Context(), envOutputFormat, "csv") - got, err := resolveOutputFormat(ctx, "json", true, true) - require.NoError(t, err) - assert.Equal(t, outputJSON, got) -} - -func TestResolveOutputFormat_InvalidEnvVarIgnored(t *testing.T) { - ctx := env.Set(t.Context(), envOutputFormat, "yaml") - got, err := resolveOutputFormat(ctx, "text", false, true) - require.NoError(t, err) - assert.Equal(t, outputText, got) -} - -func TestResolveOutputFormat_InvalidFlagErrors(t *testing.T) { - ctx := t.Context() - _, err := resolveOutputFormat(ctx, "yaml", true, true) - assert.ErrorContains(t, err, "unsupported output format") -} - -func TestResolveOutputFormat_CaseInsensitive(t *testing.T) { - ctx := t.Context() - got, err := resolveOutputFormat(ctx, "JSON", true, true) - require.NoError(t, err) - assert.Equal(t, outputJSON, got) -} diff --git a/experimental/postgres/cmd/query.go b/experimental/postgres/cmd/query.go index 2b4f12694f9..5a7f3e577cc 100644 --- a/experimental/postgres/cmd/query.go +++ b/experimental/postgres/cmd/query.go @@ -9,6 +9,7 @@ import ( "time" "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/experimental/libs/sqlcli" "github.com/databricks/cli/libs/cmdio" "github.com/jackc/pgx/v5" "github.com/spf13/cobra" @@ -89,10 +90,10 @@ Limitations (this release): cmd.Flags().StringVarP(&f.database, "database", "d", defaultDatabase, "Database name") cmd.Flags().DurationVar(&f.connectTimeout, "connect-timeout", defaultConnectTimeout, "Connect timeout") cmd.Flags().IntVar(&f.maxRetries, "max-retries", 3, "Total connect attempts on idle/waking endpoint (must be >= 1; 1 disables retry)") - cmd.Flags().StringVarP(&f.outputFormat, "output", "o", string(outputText), "Output format: text, json, or csv") + cmd.Flags().StringVarP(&f.outputFormat, "output", "o", string(sqlcli.OutputText), "Output format: text, json, or csv") cmd.RegisterFlagCompletionFunc("output", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) { - out := make([]string, len(allOutputFormats)) - for i, f := range allOutputFormats { + out := make([]string, 
len(sqlcli.AllFormats)) + for i, f := range sqlcli.AllFormats { out[i] = string(f) } return out, cobra.ShellCompDirectiveNoFileComp @@ -126,7 +127,7 @@ func runQuery(ctx context.Context, cmd *cobra.Command, sql string, f queryFlags) // pass --output text explicitly; that path is honoured (see // resolveOutputFormat). Mirrors the aitools query command. stdoutTTY := cmdio.SupportsColor(ctx, cmd.OutOrStdout()) - format, err := resolveOutputFormat(ctx, f.outputFormat, f.outputFormatSet, stdoutTTY) + format, err := sqlcli.ResolveFormat(ctx, f.outputFormat, f.outputFormatSet, stdoutTTY) if err != nil { return err } @@ -168,11 +169,11 @@ func runQuery(ctx context.Context, cmd *cobra.Command, sql string, f queryFlags) // newSink returns the rowSink for the chosen output format. Kept separate // from runQuery so tests can build sinks without going through pgx. -func newSink(format outputFormat, out, stderr io.Writer) rowSink { +func newSink(format sqlcli.Format, out, stderr io.Writer) rowSink { switch format { - case outputJSON: + case sqlcli.OutputJSON: return newJSONSink(out, stderr) - case outputCSV: + case sqlcli.OutputCSV: return newCSVSink(out, stderr) default: return newTextSink(out) From a5dff81d25b4ac21d390935ee5061e4972dc72e3 Mon Sep 17 00:00:00 2001 From: simon Date: Thu, 30 Apr 2026 15:16:57 +0200 Subject: [PATCH 17/25] Address nitpicker findings: NO_COLOR-safe TTY check, dup-column collision, control-char escape Three P2 findings from the nitpicker bot, all in code introduced or strengthened in this PR: - stdoutTTY now uses cmdio.IsOutputTTY (a new tiny public helper that wraps the existing private isTTY) instead of cmdio.SupportsColor. SupportsColor folds in NO_COLOR / TERM=dumb, which are colour preferences and have nothing to do with whether stdout is a pipe; using it for the auto-fall-back-to-JSON decision silently demoted interactive text output to JSON for users with NO_COLOR set on a real terminal. IsOutputTTY is the right primitive for this. - jsonSink dup-column rename: the previous logic generated id__2 for the second `id` without checking whether id__2 was already taken by the original column list. A query returning ["id", "id__2", "id"] produced two id__2 keys. Now we keep bumping the suffix until unique. - textSink escapes \t, \n, \r in cell values before tabwriter sees them. tabwriter uses \t as a column boundary and \n as a row boundary, so an embedded tab silently shifted subsequent columns and an embedded newline split a logical row across multiple output lines. psql does the same backslash-letter escape. Co-authored-by: Isaac --- experimental/postgres/cmd/query.go | 12 +++++----- experimental/postgres/cmd/render.go | 15 +++++++++++- experimental/postgres/cmd/render_json.go | 20 ++++++++++++---- experimental/postgres/cmd/render_json_test.go | 23 +++++++++++++++++++ experimental/postgres/cmd/render_test.go | 13 +++++++++++ libs/cmdio/tty.go | 10 ++++++++ 6 files changed, 82 insertions(+), 11 deletions(-) diff --git a/experimental/postgres/cmd/query.go b/experimental/postgres/cmd/query.go index 5a7f3e577cc..f05b3e01503 100644 --- a/experimental/postgres/cmd/query.go +++ b/experimental/postgres/cmd/query.go @@ -121,12 +121,12 @@ func runQuery(ctx context.Context, cmd *cobra.Command, sql string, f queryFlags) return err } - // SupportsColor is the public TTY-ish signal libs/cmdio exposes today; it - // also folds in NO_COLOR / TERM=dumb, which strictly speaking are colour - // preferences rather than TTY signals. 
Users who hit that edge case can - // pass --output text explicitly; that path is honoured (see - // resolveOutputFormat). Mirrors the aitools query command. - stdoutTTY := cmdio.SupportsColor(ctx, cmd.OutOrStdout()) + // IsOutputTTY checks the file-descriptor only. SupportsColor would also + // AND in NO_COLOR / TERM=dumb, which are colour preferences and have + // nothing to do with whether stdout is a pipe; folding them in here + // would silently demote interactive text output to JSON for users who + // have NO_COLOR set on a real terminal. + stdoutTTY := cmdio.IsOutputTTY(cmd.OutOrStdout()) format, err := sqlcli.ResolveFormat(ctx, f.outputFormat, f.outputFormatSet, stdoutTTY) if err != nil { return err diff --git a/experimental/postgres/cmd/render.go b/experimental/postgres/cmd/render.go index 2e1daf6376b..a3c6aa53344 100644 --- a/experimental/postgres/cmd/render.go +++ b/experimental/postgres/cmd/render.go @@ -36,12 +36,25 @@ func (s *textSink) Begin(fields []pgconn.FieldDescription) error { func (s *textSink) Row(values []any) error { row := make([]string, len(values)) for i, v := range values { - row[i] = textValue(v) + row[i] = escapeControlForTabwriter(textValue(v)) } s.rows = append(s.rows, row) return nil } +// escapeControlForTabwriter replaces tabs, newlines, and carriage returns in +// a cell value with the two-character backslash-letter sequence. tabwriter +// uses '\t' as a column boundary and '\n' as a row boundary, so an embedded +// tab silently shifts subsequent columns and an embedded newline splits one +// logical row into two. psql's text mode applies the same escapes. +func escapeControlForTabwriter(s string) string { + if !strings.ContainsAny(s, "\t\n\r") { + return s + } + r := strings.NewReplacer("\t", `\t`, "\n", `\n`, "\r", `\r`) + return r.Replace(s) +} + func (s *textSink) End(commandTag string) error { if len(s.columns) == 0 { _, err := fmt.Fprintln(s.out, commandTag) diff --git a/experimental/postgres/cmd/render_json.go b/experimental/postgres/cmd/render_json.go index dc713b6b786..c50739e2d1f 100644 --- a/experimental/postgres/cmd/render_json.go +++ b/experimental/postgres/cmd/render_json.go @@ -52,16 +52,28 @@ func (s *jsonSink) Begin(fields []pgconn.FieldDescription) error { s.columns = make([]string, len(fields)) s.oids = make([]uint32, len(fields)) - seen := make(map[string]int, len(fields)) + // assigned tracks every name we have committed to s.columns so far. This + // must include both first-occurrence names and __N suffixed renames, so a + // query whose original column list contains the same suffix we'd generate + // (e.g. ["id", "id__2", "id"]) does not produce two id__2 keys. 
+ assigned := make(map[string]struct{}, len(fields)) dupes := false for i, f := range fields { s.oids[i] = f.DataTypeOID name := f.Name - seen[name]++ - if seen[name] > 1 { + if _, taken := assigned[name]; taken { dupes = true - name = fmt.Sprintf("%s__%d", f.Name, seen[name]) + suffix := 2 + for { + candidate := fmt.Sprintf("%s__%d", f.Name, suffix) + if _, taken := assigned[candidate]; !taken { + name = candidate + break + } + suffix++ + } } + assigned[name] = struct{}{} s.columns[i] = name } if dupes { diff --git a/experimental/postgres/cmd/render_json_test.go b/experimental/postgres/cmd/render_json_test.go index 4e6f474d257..9cf386cb14d 100644 --- a/experimental/postgres/cmd/render_json_test.go +++ b/experimental/postgres/cmd/render_json_test.go @@ -2,6 +2,7 @@ package postgrescmd import ( "bytes" + "strings" "testing" "github.com/jackc/pgx/v5/pgconn" @@ -136,3 +137,25 @@ func TestCommandTagParse(t *testing.T) { } } } + +func TestJSONSink_DuplicateColumns_DoesNotCollideWithExistingSuffix(t *testing.T) { + // Source columns ["id", "id__2", "id"]: the second `id` would naively + // rename to id__2, colliding with the existing id__2 from the source. + // Verify the dedup logic bumps the suffix until unique. + var stdout, stderr bytes.Buffer + s := newJSONSink(&stdout, &stderr) + require.NoError(t, s.Begin(fieldsWithOIDs( + []string{"id", "id__2", "id"}, + []uint32{pgtype.Int8OID, pgtype.Int8OID, pgtype.Int8OID}, + ))) + require.NoError(t, s.Row([]any{int64(1), int64(2), int64(3)})) + require.NoError(t, s.End("SELECT 1")) + + // All three keys present with no duplicates. + out := stdout.String() + assert.Contains(t, out, `"id":1`) + assert.Contains(t, out, `"id__2":2`) + assert.Contains(t, out, `"id__3":3`) + // And NOT two id__2 keys. + assert.Equal(t, 1, strings.Count(out, `"id__2"`)) +} diff --git a/experimental/postgres/cmd/render_test.go b/experimental/postgres/cmd/render_test.go index d451febb191..bdd2bddd4f6 100644 --- a/experimental/postgres/cmd/render_test.go +++ b/experimental/postgres/cmd/render_test.go @@ -83,3 +83,16 @@ func TestTextSink_OnError_NoOp(t *testing.T) { // is never flushed. assert.Empty(t, buf.String()) } + +func TestTextSink_EscapesTabAndNewlineInCells(t *testing.T) { + var buf bytes.Buffer + s := newTextSink(&buf) + require.NoError(t, s.Begin(fields("note"))) + require.NoError(t, s.Row([]any{"a\tb\nc\rd"})) + require.NoError(t, s.End("SELECT 1")) + // The escape replaces tabs/newlines/CR with their backslash-letter forms + // so the tabwriter doesn't treat them as column or row boundaries. + assert.Contains(t, buf.String(), `a\tb\nc\rd`) + assert.NotContains(t, buf.String(), "a\tb") + assert.NotContains(t, buf.String(), "c\rd") +} diff --git a/libs/cmdio/tty.go b/libs/cmdio/tty.go index 40148bb0895..c2607b8909f 100644 --- a/libs/cmdio/tty.go +++ b/libs/cmdio/tty.go @@ -7,6 +7,16 @@ import ( "github.com/mattn/go-isatty" ) +// IsOutputTTY reports whether w is connected to a terminal. Unlike +// SupportsColor this does NOT consult NO_COLOR or TERM=dumb, which are +// colour preferences and not TTY signals. Use this when a command needs +// to decide "should I default to interactive output" or "should I +// auto-fall-back to machine-readable output on a pipe", and use +// SupportsColor only for the colour-rendering decision itself. +func IsOutputTTY(w io.Writer) bool { + return isTTY(w) +} + // isTTY detects if the given reader or writer is a terminal. func isTTY(v any) bool { // Check if it's a fakeTTY first. 
From a275d48117e3265545b4bf95cbeb9ae9ff281afd Mon Sep 17 00:00:00 2001 From: simon Date: Thu, 30 Apr 2026 15:20:13 +0200 Subject: [PATCH 18/25] Address PR 4 nitpicker finding: TUI fallback on tableview.Run error MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit If tableview.Run errors (TUI startup failure, terminal resize race, etc.), fall through to the static tabwriter path instead of returning the error to the caller. Without the fallback, a successful query surfaces as "viewer failed" with no data — the user paid for the query but doesn't see the rows. Co-authored-by: Isaac --- experimental/postgres/cmd/render.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/experimental/postgres/cmd/render.go b/experimental/postgres/cmd/render.go index d00b4824e46..0d09556ece7 100644 --- a/experimental/postgres/cmd/render.go +++ b/experimental/postgres/cmd/render.go @@ -81,7 +81,13 @@ func (s *textSink) End(commandTag string) error { } if s.interactive && len(s.rows) > staticTableThreshold { - return tableview.Run(s.out, s.columns, s.rows) + // Try the interactive viewer; on failure (TUI startup, terminal + // resize race, etc.) fall through to the static path so the user + // still sees the rows their query returned. Without this fallback + // a successful query would surface as "viewer failed" with no data. + if err := tableview.Run(s.out, s.columns, s.rows); err == nil { + return nil + } } tw := tabwriter.NewWriter(s.out, 0, 0, 2, ' ', 0) From e81ab278c761501fb2b2058b658563bddbd3bcce Mon Sep 17 00:00:00 2001 From: simon Date: Fri, 1 May 2026 09:01:00 +0200 Subject: [PATCH 19/25] PR 1 lint fix: drop unused provisioned helpers from internal/target This PR only uses autoscaling targeting; provisioned helpers in internal/target/provisioned.go have no caller in PR 1's net diff, which the task deadcode check (run by CI's lint job, not lint-q) correctly flags. Provisioned support lands in PR 2; the necessary subset of helpers (GetProvisioned, ProvisionedCredential) is added there alongside the first caller. Co-authored-by: Isaac --- .../cmd/internal/target/provisioned.go | 66 ------------------- 1 file changed, 66 deletions(-) delete mode 100644 experimental/postgres/cmd/internal/target/provisioned.go diff --git a/experimental/postgres/cmd/internal/target/provisioned.go b/experimental/postgres/cmd/internal/target/provisioned.go deleted file mode 100644 index 261ef37a6a8..00000000000 --- a/experimental/postgres/cmd/internal/target/provisioned.go +++ /dev/null @@ -1,66 +0,0 @@ -package target - -import ( - "context" - "errors" - "fmt" - - "github.com/databricks/databricks-sdk-go" - "github.com/databricks/databricks-sdk-go/service/database" - "github.com/google/uuid" -) - -// ListProvisionedInstances returns all provisioned database instances in the workspace. -func ListProvisionedInstances(ctx context.Context, w *databricks.WorkspaceClient) ([]database.DatabaseInstance, error) { - return w.Database.ListDatabaseInstancesAll(ctx, database.ListDatabaseInstancesRequest{}) -} - -// GetProvisioned fetches a single provisioned instance by name. -// The Name field on the response can be empty; this function ensures it is -// populated from the input so downstream callers do not have to re-set it. 
-func GetProvisioned(ctx context.Context, w *databricks.WorkspaceClient, name string) (*database.DatabaseInstance, error) { - instance, err := w.Database.GetDatabaseInstance(ctx, database.GetDatabaseInstanceRequest{Name: name}) - if err != nil { - return nil, fmt.Errorf("failed to get database instance: %w", err) - } - if instance.Name == "" { - instance.Name = name - } - return instance, nil -} - -// AutoSelectProvisioned returns the only provisioned instance's name (e.g. -// "my-instance"; the database SDK uses flat names, not the "projects/..." -// path shape used by autoscaling). Returns an *AmbiguousError if there are -// multiple, or a plain error if none. -func AutoSelectProvisioned(ctx context.Context, w *databricks.WorkspaceClient) (string, error) { - instances, err := ListProvisionedInstances(ctx, w) - if err != nil { - return "", err - } - if len(instances) == 0 { - return "", errors.New("no Lakebase Provisioned instances found in workspace") - } - if len(instances) == 1 { - return instances[0].Name, nil - } - - choices := make([]Choice, 0, len(instances)) - for _, inst := range instances { - choices = append(choices, Choice{ID: inst.Name}) - } - return "", &AmbiguousError{Kind: KindInstance, FlagHint: "--target", Choices: choices} -} - -// ProvisionedCredential issues a short-lived OAuth token for the provisioned -// instance with the given name. -func ProvisionedCredential(ctx context.Context, w *databricks.WorkspaceClient, instanceName string) (string, error) { - cred, err := w.Database.GenerateDatabaseCredential(ctx, database.GenerateDatabaseCredentialRequest{ - InstanceNames: []string{instanceName}, - RequestId: uuid.NewString(), - }) - if err != nil { - return "", fmt.Errorf("failed to get database credentials: %w", err) - } - return cred.Token, nil -} From f714c237979c3f2ddb551d52a6b991e8067d617a Mon Sep 17 00:00:00 2001 From: simon Date: Fri, 1 May 2026 09:01:55 +0200 Subject: [PATCH 20/25] PR 2 lint fix: re-add provisioned.go with only the helpers used here PR 1's lint fix dropped the entire provisioned.go because PR 1 had no caller. Re-add a slim version with just GetProvisioned and ProvisionedCredential (the two functions resolveProvisioned actually calls). Drop ListProvisionedInstances and AutoSelectProvisioned: they were originally intended for cmd/psql interactive selection, but the cmd/psql refactor was reverted, so they have no caller anywhere. Co-authored-by: Isaac --- .../cmd/internal/target/provisioned.go | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 experimental/postgres/cmd/internal/target/provisioned.go diff --git a/experimental/postgres/cmd/internal/target/provisioned.go b/experimental/postgres/cmd/internal/target/provisioned.go new file mode 100644 index 00000000000..786e86d2886 --- /dev/null +++ b/experimental/postgres/cmd/internal/target/provisioned.go @@ -0,0 +1,37 @@ +package target + +import ( + "context" + "fmt" + + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/service/database" + "github.com/google/uuid" +) + +// GetProvisioned fetches a single provisioned instance by name. +// The Name field on the response can be empty; this function ensures it is +// populated from the input so downstream callers do not have to re-set it. 
+func GetProvisioned(ctx context.Context, w *databricks.WorkspaceClient, name string) (*database.DatabaseInstance, error) { + instance, err := w.Database.GetDatabaseInstance(ctx, database.GetDatabaseInstanceRequest{Name: name}) + if err != nil { + return nil, fmt.Errorf("failed to get database instance: %w", err) + } + if instance.Name == "" { + instance.Name = name + } + return instance, nil +} + +// ProvisionedCredential issues a short-lived OAuth token for the provisioned +// instance with the given name. +func ProvisionedCredential(ctx context.Context, w *databricks.WorkspaceClient, instanceName string) (string, error) { + cred, err := w.Database.GenerateDatabaseCredential(ctx, database.GenerateDatabaseCredentialRequest{ + InstanceNames: []string{instanceName}, + RequestId: uuid.NewString(), + }) + if err != nil { + return "", fmt.Errorf("failed to get database credentials: %w", err) + } + return cred.Token, nil +} From 4c4433420d30c4f8faea6347e974de6009aa8a5c Mon Sep 17 00:00:00 2001 From: simon Date: Tue, 5 May 2026 11:08:11 +0200 Subject: [PATCH 21/25] Fix TLS missing in postgres query connect pgx.ParseConfig with an empty host falls back to a unix-socket path and sets TLSConfig=nil. Patching Host after the parse leaves TLSConfig nil, so the connection goes plaintext and Lakebase rejects the pgwire startup ("Invalid protocol version: 196608"). Build the DSN with the real host so pgx derives TLSConfig correctly, and keep user/password/connect-timeout as field patches. Co-authored-by: Isaac --- experimental/postgres/cmd/connect.go | 21 +++++++++++++-------- experimental/postgres/cmd/connect_test.go | 6 ++++++ 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/experimental/postgres/cmd/connect.go b/experimental/postgres/cmd/connect.go index 2eefc681868..a211e19b1ce 100644 --- a/experimental/postgres/cmd/connect.go +++ b/experimental/postgres/cmd/connect.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net" + "net/url" "time" "github.com/databricks/cli/libs/cmdio" @@ -48,20 +49,24 @@ type retryConfig struct { // is exercised by integration tests against real Lakebase endpoints. type connectFunc func(ctx context.Context, cfg *pgx.ConnConfig) (*pgx.Conn, error) -// buildPgxConfig parses a base DSN to inherit pgx's TLS shape, then patches -// in the resolved values. The DSN-then-patch pattern is the recommended way -// to configure pgx for `sslmode=require` because building a pgx.ConnConfig -// by hand omits internal fields that the parser sets. +// buildPgxConfig parses a DSN that includes the real host so pgx derives the +// right TLSConfig and Fallbacks for sslmode=require. An empty host in the DSN +// makes pgx fall back to defaultHost(), which resolves to a unix-socket path. +// pgconn classifies that as a unix socket and assigns TLSConfig=nil; patching +// cfg.Host after the parse does not re-derive TLSConfig, so the connection +// goes out in plaintext and Lakebase rejects the pgwire startup with +// "Invalid protocol version: 196608". User, password, and connect timeout are +// patched as fields because tokens can contain characters that would need +// URL-escaping in userinfo. 
func buildPgxConfig(c connectConfig) (*pgx.ConnConfig, error) { - cfg, err := pgx.ParseConfig("postgresql:///?sslmode=require") + dsn := fmt.Sprintf("postgresql://%s:%d/%s?sslmode=require", + c.Host, c.Port, url.PathEscape(c.Database)) + cfg, err := pgx.ParseConfig(dsn) if err != nil { return nil, fmt.Errorf("parse pgx config: %w", err) } - cfg.Host = c.Host - cfg.Port = uint16(c.Port) cfg.User = c.Username cfg.Password = c.Password - cfg.Database = c.Database cfg.ConnectTimeout = c.ConnectTimeout return cfg, nil } diff --git a/experimental/postgres/cmd/connect_test.go b/experimental/postgres/cmd/connect_test.go index d58fc52cc74..fd294ef2765 100644 --- a/experimental/postgres/cmd/connect_test.go +++ b/experimental/postgres/cmd/connect_test.go @@ -146,4 +146,10 @@ func TestBuildPgxConfig(t *testing.T) { assert.Equal(t, "secret", cfg.Password) assert.Equal(t, "db", cfg.Database) assert.Equal(t, 30*time.Second, cfg.ConnectTimeout) + + // sslmode=require must produce a non-nil TLSConfig for the real host. + // Connecting in plaintext makes Lakebase reject the pgwire startup with + // "Invalid protocol version: 196608". + require.NotNil(t, cfg.TLSConfig, "TLSConfig must be set for sslmode=require") + assert.Equal(t, "host.example.com", cfg.TLSConfig.ServerName) } From a73f48484301263ddabe749fe33ff93712e1a28a Mon Sep 17 00:00:00 2001 From: simon Date: Tue, 5 May 2026 11:12:51 +0200 Subject: [PATCH 22/25] Fix multi-input column-name aliasing in bufferSink bufferSink.Begin stashed the FieldDescription slice it was handed. pgx reuses that slice's backing array across queries on the same connection (pgConn.fieldDescriptions is a fixed-size buffer that's re-sliced per statement), so each buffered unit's Fields ended up pointing at the LAST query's row description. The multi-input renderers then emitted the wrong column names for every unit but the last. Clone the slice so each buffered unit owns its column descriptions. Co-authored-by: Isaac --- experimental/postgres/cmd/result.go | 8 ++++- experimental/postgres/cmd/result_test.go | 41 ++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 1 deletion(-) create mode 100644 experimental/postgres/cmd/result_test.go diff --git a/experimental/postgres/cmd/result.go b/experimental/postgres/cmd/result.go index 211c433ac4b..40267260af1 100644 --- a/experimental/postgres/cmd/result.go +++ b/experimental/postgres/cmd/result.go @@ -2,6 +2,7 @@ package postgrescmd import ( "context" + "slices" "time" "github.com/databricks/cli/experimental/libs/sqlcli" @@ -53,7 +54,12 @@ type bufferSink struct { } func (s *bufferSink) Begin(fields []pgconn.FieldDescription) error { - s.result.Fields = fields + // pgx reuses the FieldDescription backing array across queries on the same + // connection (pgConn.fieldDescriptions is a fixed-size buffer that's + // re-sliced per statement). Clone here so a buffered unit holds onto its + // own column descriptions; otherwise the multi-input renderers see every + // unit's Fields aliased to the last query's row description. 
+ s.result.Fields = slices.Clone(fields) return nil } diff --git a/experimental/postgres/cmd/result_test.go b/experimental/postgres/cmd/result_test.go new file mode 100644 index 00000000000..22872d3bbd4 --- /dev/null +++ b/experimental/postgres/cmd/result_test.go @@ -0,0 +1,41 @@ +package postgrescmd + +import ( + "testing" + + "github.com/jackc/pgx/v5/pgconn" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBufferSink_BeginClonesFields(t *testing.T) { + r := &unitResult{} + s := &bufferSink{result: r} + + // pgx hands Begin a slice whose backing array gets reused for the next + // query on the same connection. Mutating the caller's slice after Begin + // must not change what the buffered result holds. + fields := []pgconn.FieldDescription{ + {Name: "first_col", DataTypeOID: 23}, + } + require.NoError(t, s.Begin(fields)) + + fields[0] = pgconn.FieldDescription{Name: "second_col", DataTypeOID: 25} + + require.Len(t, r.Fields, 1) + assert.Equal(t, "first_col", r.Fields[0].Name) + assert.Equal(t, uint32(23), r.Fields[0].DataTypeOID) +} + +func TestBufferSink_RowAndEnd(t *testing.T) { + r := &unitResult{} + s := &bufferSink{result: r} + + require.NoError(t, s.Begin([]pgconn.FieldDescription{{Name: "a"}})) + require.NoError(t, s.Row([]any{int64(1)})) + require.NoError(t, s.Row([]any{int64(2)})) + require.NoError(t, s.End("SELECT 2")) + + assert.Equal(t, [][]any{{int64(1)}, {int64(2)}}, r.Rows) + assert.Equal(t, "SELECT 2", r.CommandTag) +} From a51dc831131b920d5d2d19d73fb8c9dbb4ad593c Mon Sep 17 00:00:00 2001 From: simon Date: Tue, 5 May 2026 13:13:27 +0200 Subject: [PATCH 23/25] Use net.JoinHostPort in pgx DSN to satisfy nosprintfhostport The golangci-lint nosprintfhostport check flags fmt.Sprintf with %s:%d for host:port in URLs. Switch to net.JoinHostPort. Co-authored-by: Isaac --- experimental/postgres/cmd/connect.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/experimental/postgres/cmd/connect.go b/experimental/postgres/cmd/connect.go index a211e19b1ce..b2038efac45 100644 --- a/experimental/postgres/cmd/connect.go +++ b/experimental/postgres/cmd/connect.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "net/url" + "strconv" "time" "github.com/databricks/cli/libs/cmdio" @@ -59,8 +60,9 @@ type connectFunc func(ctx context.Context, cfg *pgx.ConnConfig) (*pgx.Conn, erro // patched as fields because tokens can contain characters that would need // URL-escaping in userinfo. func buildPgxConfig(c connectConfig) (*pgx.ConnConfig, error) { - dsn := fmt.Sprintf("postgresql://%s:%d/%s?sslmode=require", - c.Host, c.Port, url.PathEscape(c.Database)) + dsn := fmt.Sprintf("postgresql://%s/%s?sslmode=require", + net.JoinHostPort(c.Host, strconv.Itoa(c.Port)), + url.PathEscape(c.Database)) cfg, err := pgx.ParseConfig(dsn) if err != nil { return nil, fmt.Errorf("parse pgx config: %w", err) From 6757d4e26f01e3592927005a81115f434687f0c9 Mon Sep 17 00:00:00 2001 From: simon Date: Tue, 5 May 2026 14:37:07 +0200 Subject: [PATCH 24/25] Show connecting status as a spinner that clears on success The previous "Connecting to ..." line went to stderr but stayed in the terminal forever, even after results arrived. Use cmdio.NewSpinner so the status disappears once the connection succeeds. 
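
In sketch form, the pattern as applied in runQuery below:

    sp := cmdio.NewSpinner(ctx)
    sp.Update("Connecting to " + resolved.DisplayName)
    conn, err := connectWithRetry(ctx, pgxCfg, rc, pgx.ConnectConfig)
    sp.Close()
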
Co-authored-by: Isaac --- experimental/postgres/cmd/query.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/experimental/postgres/cmd/query.go b/experimental/postgres/cmd/query.go index fe5cc528ea7..47b3a00755f 100644 --- a/experimental/postgres/cmd/query.go +++ b/experimental/postgres/cmd/query.go @@ -100,8 +100,6 @@ func runQuery(ctx context.Context, cmd *cobra.Command, sql string, f queryFlags) return err } - cmdio.LogString(ctx, fmt.Sprintf("Connecting to %s...", resolved.DisplayName)) - pgxCfg, err := buildPgxConfig(connectConfig{ Host: resolved.Host, Port: 5432, @@ -120,7 +118,13 @@ func runQuery(ctx context.Context, cmd *cobra.Command, sql string, f queryFlags) MaxDelay: 10 * time.Second, } + // Spinner clears its line on Close, so the "Connecting to ..." status + // disappears once the connection is up. cmdio.NewSpinner already writes + // to stderr and degrades to a no-op in non-interactive terminals. + sp := cmdio.NewSpinner(ctx) + sp.Update("Connecting to " + resolved.DisplayName) conn, err := connectWithRetry(ctx, pgxCfg, rc, pgx.ConnectConfig) + sp.Close() if err != nil { return err } From b1948e134d6ab26252b0cbd54c4f44c7691d5616 Mon Sep 17 00:00:00 2001 From: simon Date: Tue, 5 May 2026 14:38:37 +0200 Subject: [PATCH 25/25] Sticky header in tableview TUI The header (column names + separator) was part of the scrollable viewport content, so once the viewport's YOffset moved past the first two lines, the header scrolled off and never came back, even after scrolling all the way to the top of the data. Render the header outside the viewport so it stays visible while the data scrolls. Cursor index now points into the data rows directly. Viewport height is reduced by the header height. Added tests covering the sticky-render position, dataRowCount semantics, and cursor clamping. Co-authored-by: Isaac --- libs/tableview/tableview.go | 32 ++++++++++++++-------- libs/tableview/tableview_test.go | 47 ++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 12 deletions(-) diff --git a/libs/tableview/tableview.go b/libs/tableview/tableview.go index 54266b72f2f..df47a29f85d 100644 --- a/libs/tableview/tableview.go +++ b/libs/tableview/tableview.go @@ -17,6 +17,7 @@ const ( footerHeight = 1 searchFooterHeight = 2 // headerLines is the number of non-data lines at the top (header + separator). + // These are rendered above the viewport so they stay visible while data scrolls. headerLines = 2 ) @@ -30,11 +31,14 @@ var ( // Run displays tabular data in an interactive browser. // Writes to w (typically stdout). Blocks until user quits. func Run(w io.Writer, columns []string, rows [][]string) error { - lines := renderTableLines(columns, rows) + all := renderTableLines(columns, rows) + header := all[:headerLines] + dataLines := all[headerLines:] m := model{ - lines: lines, - cursor: headerLines, // Start on first data row. + header: header, + lines: dataLines, + cursor: 0, } p := tea.NewProgram(m, tea.WithOutput(w)) @@ -144,20 +148,21 @@ func (m model) renderContent() string { type model struct { //nolint:recvcheck // value receivers for tea.Model interface, pointer for cursor mutation viewport viewport.Model - lines []string + header []string // sticky header lines (column names + separator) + lines []string // data rows only ready bool - cursor int // line index of the highlighted row + cursor int // index into lines (data rows) // Search state. 
searching bool searchInput string searchQuery string - matchLines []int + matchLines []int // indices into lines matchIdx int } func (m model) dataRowCount() int { - return max(len(m.lines)-headerLines, 0) + return len(m.lines) } func (m model) Init() tea.Cmd { @@ -171,14 +176,16 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if m.searching { fh = searchFooterHeight } + // Reserve room for the sticky header above the viewport. + height := msg.Height - fh - len(m.header) if !m.ready { - m.viewport = viewport.New(msg.Width, msg.Height-fh) + m.viewport = viewport.New(msg.Width, height) m.viewport.SetHorizontalStep(horizontalScrollStep) m.viewport.SetContent(m.renderContent()) m.ready = true } else { m.viewport.Width = msg.Width - m.viewport.Height = msg.Height - fh + m.viewport.Height = height } return m, nil @@ -232,7 +239,7 @@ func (m model) updateNormal(msg tea.KeyMsg) (tea.Model, tea.Cmd) { m.moveCursor(m.viewport.Height) return m, nil case "g": - m.cursor = headerLines + m.cursor = 0 m.viewport.SetContent(m.renderContent()) m.viewport.GotoTop() return m, nil @@ -252,7 +259,7 @@ func (m model) updateNormal(msg tea.KeyMsg) (tea.Model, tea.Cmd) { // moveCursor moves the cursor by delta lines, clamped to data rows. func (m *model) moveCursor(delta int) { m.cursor += delta - m.cursor = max(m.cursor, headerLines) + m.cursor = max(m.cursor, 0) m.cursor = min(m.cursor, len(m.lines)-1) m.viewport.SetContent(m.renderContent()) m.scrollToCursor() @@ -311,7 +318,8 @@ func (m model) View() string { } footer := m.renderFooter() - return m.viewport.View() + "\n" + footer + header := strings.Join(m.header, "\n") + return header + "\n" + m.viewport.View() + "\n" + footer } func (m model) renderFooter() string { diff --git a/libs/tableview/tableview_test.go b/libs/tableview/tableview_test.go index c761a9cf007..d1fd2b964c2 100644 --- a/libs/tableview/tableview_test.go +++ b/libs/tableview/tableview_test.go @@ -1,8 +1,10 @@ package tableview import ( + "strings" "testing" + "github.com/charmbracelet/bubbles/viewport" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -70,3 +72,48 @@ func TestHighlightSearchNoMatch(t *testing.T) { result := highlightSearch("hello bob", "alice") assert.Equal(t, "hello bob", result) } + +// readyModel constructs a model in the same shape Run produces, plus a viewport +// large enough that the cursor visibility logic does not need to scroll. +func readyModel(columns []string, rows [][]string, viewportHeight int) model { + all := renderTableLines(columns, rows) + m := model{ + header: all[:headerLines], + lines: all[headerLines:], + } + m.viewport = viewport.New(80, viewportHeight) + m.viewport.SetContent(m.renderContent()) + m.ready = true + return m +} + +func TestViewKeepsHeaderAboveScrollableContent(t *testing.T) { + columns := []string{"id", "name"} + rows := [][]string{{"1", "alice"}, {"2", "bob"}, {"3", "carol"}} + m := readyModel(columns, rows, 2) + + // Scroll the viewport down so the first data row falls below the top + // of the viewport. Before the sticky-header change this would also push + // the column header off-screen and never bring it back. 
+ m.viewport.SetYOffset(1) + + out := m.View() + headerIdx := strings.Index(out, "id") + carolIdx := strings.Index(out, "carol") + require.NotEqual(t, -1, headerIdx, "View output must contain the column header") + require.NotEqual(t, -1, carolIdx, "View output must contain the visible row after scrolling") + assert.Less(t, headerIdx, carolIdx, "column header must render above the scrolled rows") +} + +func TestModelDataRowCountExcludesHeader(t *testing.T) { + m := readyModel([]string{"id"}, [][]string{{"1"}, {"2"}, {"3"}}, 5) + assert.Equal(t, 3, m.dataRowCount()) +} + +func TestMoveCursorClampsAtZeroAndLast(t *testing.T) { + m := readyModel([]string{"id"}, [][]string{{"1"}, {"2"}, {"3"}}, 5) + m.moveCursor(-100) + assert.Equal(t, 0, m.cursor, "cursor should clamp to first data row, not below") + m.moveCursor(100) + assert.Equal(t, 2, m.cursor, "cursor should clamp to last data row") +}