diff --git a/experimental/postgres/cmd/cancel_test.go b/experimental/postgres/cmd/cancel_test.go new file mode 100644 index 0000000000..4245b905ef --- /dev/null +++ b/experimental/postgres/cmd/cancel_test.go @@ -0,0 +1,89 @@ +package postgrescmd + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestWithStatementTimeout_ZeroIsPassthrough(t *testing.T) { + parent := t.Context() + got, cancel := withStatementTimeout(parent, 0) + defer cancel() + // Parent and got should compare equal: zero timeout returns the parent + // unchanged (and a no-op cancel). + deadline, ok := got.Deadline() + assert.False(t, ok, "deadline should not be set when timeout is 0") + assert.True(t, deadline.IsZero()) +} + +func TestWithStatementTimeout_AppliesDeadline(t *testing.T) { + parent := t.Context() + got, cancel := withStatementTimeout(parent, time.Second) + defer cancel() + deadline, ok := got.Deadline() + assert.True(t, ok) + assert.False(t, deadline.IsZero()) +} + +func TestReportCancellation_SignalCanceled(t *testing.T) { + signalCtx, signalCancel := context.WithCancel(t.Context()) + signalCancel() + stmtCtx := signalCtx + msg, invocationScoped := reportCancellation(signalCtx, stmtCtx, errors.New("anything"), 0) + assert.Equal(t, "Query cancelled.", msg) + assert.True(t, invocationScoped) +} + +func TestReportCancellation_TimeoutFired(t *testing.T) { + signalCtx := t.Context() + stmtCtx, stmtCancel := context.WithDeadline(signalCtx, time.Now().Add(-time.Second)) + defer stmtCancel() + <-stmtCtx.Done() + msg, invocationScoped := reportCancellation(signalCtx, stmtCtx, errors.New("query failed"), 5*time.Second) + assert.Equal(t, "Query timed out after 5s.", msg) + assert.True(t, invocationScoped) +} + +func TestReportCancellation_GenericError(t *testing.T) { + signalCtx := t.Context() + stmtCtx := signalCtx + msg, invocationScoped := reportCancellation(signalCtx, stmtCtx, errors.New("syntax error"), 0) + assert.Equal(t, "syntax error", msg) + assert.False(t, invocationScoped) +} + +func TestReportCancellation_BothFire_CancelWinsRace(t *testing.T) { + // User cancel and deadline both already done. Precedence: cancel wins + // (the user's intent dominates a coincidental deadline). A future + // reordering of the switch would silently flip this; the test pins it. + signalCtx, signalCancel := context.WithCancel(t.Context()) + signalCancel() + stmtCtx, stmtCancel := context.WithDeadline(signalCtx, time.Now().Add(-time.Second)) + defer stmtCancel() + <-stmtCtx.Done() + msg, invocationScoped := reportCancellation(signalCtx, stmtCtx, errors.New("anything"), time.Second) + assert.Equal(t, "Query cancelled.", msg) + assert.True(t, invocationScoped) +} + +func TestWatchInterruptSignals_CancelOnStop(t *testing.T) { + // stop should cancel the parent context as a side-effect so the goroutine + // terminates promptly. We don't actually send a SIGINT here (it would + // also kill the test runner); we just verify stop cleans up. + parent, parentCancel := context.WithCancel(t.Context()) + defer parentCancel() + + cancelled := false + cancel := func() { + cancelled = true + parentCancel() + } + + stop := watchInterruptSignals(parent, cancel) + stop() + assert.True(t, cancelled, "stop should call cancel to wake the goroutine") +} diff --git a/experimental/postgres/cmd/connect.go b/experimental/postgres/cmd/connect.go index b2038efac4..0cc47c998a 100644 --- a/experimental/postgres/cmd/connect.go +++ b/experimental/postgres/cmd/connect.go @@ -13,6 +13,7 @@ import ( "github.com/databricks/cli/libs/log" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgconn/ctxwatch" ) // defaultConnectTimeout is the dial timeout for a single connect attempt. @@ -59,6 +60,19 @@ type connectFunc func(ctx context.Context, cfg *pgx.ConnConfig) (*pgx.Conn, erro // "Invalid protocol version: 196608". User, password, and connect timeout are // patched as fields because tokens can contain characters that would need // URL-escaping in userinfo. +// +// The context-watcher handler is overridden so context cancellation issues +// a Postgres CancelRequest on the side-channel rather than only closing the +// underlying TCP connection. Without this override, a Ctrl+C during a long +// SELECT would tear down the TCP socket but leave the server-side query +// running until it noticed the broken connection on its next write. +// +// CancelRequestDelay = 0: send the cancel-request immediately on ctx cancel. +// The user just hit Ctrl+C; we want the server to learn now. +// DeadlineDelay = 5s: if the cancel-request has not gotten the server to +// terminate the query within 5s, fall back to deadlining the connection. +// Zero DeadlineDelay would race the cancel-request and could leave the +// connection unusable. func buildPgxConfig(c connectConfig) (*pgx.ConnConfig, error) { dsn := fmt.Sprintf("postgresql://%s/%s?sslmode=require", net.JoinHostPort(c.Host, strconv.Itoa(c.Port)), @@ -70,6 +84,14 @@ func buildPgxConfig(c connectConfig) (*pgx.ConnConfig, error) { cfg.User = c.Username cfg.Password = c.Password cfg.ConnectTimeout = c.ConnectTimeout + + cfg.BuildContextWatcherHandler = func(pgc *pgconn.PgConn) ctxwatch.Handler { + return &pgconn.CancelRequestContextWatcherHandler{ + Conn: pgc, + CancelRequestDelay: 0, + DeadlineDelay: 5 * time.Second, + } + } return cfg, nil } diff --git a/experimental/postgres/cmd/query.go b/experimental/postgres/cmd/query.go index 7ca86ca131..ca28d6e84a 100644 --- a/experimental/postgres/cmd/query.go +++ b/experimental/postgres/cmd/query.go @@ -27,6 +27,7 @@ type queryFlags struct { connectTimeout time.Duration maxRetries int files []string + timeout time.Duration // outputFormat is the raw flag value. resolveOutputFormat turns it into // the effective format (which may differ when stdout is piped). @@ -96,6 +97,7 @@ Limitations (this release): cmd.Flags().StringVarP(&f.database, "database", "d", defaultDatabase, "Database name") cmd.Flags().DurationVar(&f.connectTimeout, "connect-timeout", defaultConnectTimeout, "Connect timeout") cmd.Flags().IntVar(&f.maxRetries, "max-retries", 3, "Total connect attempts on idle/waking endpoint (must be >= 1; 1 disables retry)") + cmd.Flags().DurationVar(&f.timeout, "timeout", 0, "Per-statement timeout (0 disables)") cmd.Flags().StringArrayVarP(&f.files, "file", "f", nil, "SQL file path (repeatable)") cmd.Flags().StringVarP(&f.outputFormat, "output", "o", string(sqlcli.OutputText), "Output format: text, json, or csv") cmd.RegisterFlagCompletionFunc("output", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) { @@ -176,16 +178,27 @@ func runQuery(ctx context.Context, cmd *cobra.Command, args []string, f queryFla MaxDelay: 10 * time.Second, } + // Invocation-scoped context: cancelled by Ctrl+C/SIGTERM. Owns the + // connection lifecycle. Per-statement timeouts are children of this so + // a cancelled invocation also cancels the in-flight statement. + signalCtx, signalCancel := context.WithCancel(ctx) + defer signalCancel() + + stopSignals := watchInterruptSignals(signalCtx, signalCancel) + defer stopSignals() + // Spinner clears its line on Close, so the "Connecting to ..." status // disappears once the connection is up. cmdio.NewSpinner already writes // to stderr and degrades to a no-op in non-interactive terminals. - sp := cmdio.NewSpinner(ctx) + sp := cmdio.NewSpinner(signalCtx) sp.Update("Connecting to " + resolved.DisplayName) - conn, err := connectWithRetry(ctx, pgxCfg, rc, pgx.ConnectConfig) + conn, err := connectWithRetry(signalCtx, pgxCfg, rc, pgx.ConnectConfig) sp.Close() if err != nil { return err } + // Close on a background ctx so a cancelled signalCtx does not abort a + // clean teardown handshake. defer conn.Close(context.WithoutCancel(ctx)) out := cmd.OutOrStdout() @@ -196,9 +209,16 @@ func runQuery(ctx context.Context, cmd *cobra.Command, args []string, f queryFla // Avoids buffering rows for large exports and matches the v1 single- // input behaviour PR 2 shipped. Wrap the error so DETAIL / HINT // from a *pgconn.PgError surface even on the single-input path. - sink := newSink(format, out, stderr) - if err := executeOne(ctx, conn, units[0].SQL, sink); err != nil { - return errors.New(formatPgError(err)) + // Promote-to-interactive only when stdout is a prompt-capable TTY so + // a pipe falls back to the static table rather than launching a TUI + // into a dead writer. + sink := newSinkInteractive(format, out, stderr, stdoutTTY && cmdio.IsPromptSupported(ctx)) + stmtCtx, stmtCancel := withStatementTimeout(signalCtx, f.timeout) + err := executeOne(stmtCtx, conn, units[0].SQL, sink) + stmtCancel() + if err != nil { + msg, _ := reportCancellation(signalCtx, stmtCtx, err, f.timeout) + return errors.New(msg) } return nil } @@ -209,7 +229,9 @@ func runQuery(ctx context.Context, cmd *cobra.Command, args []string, f queryFla // temp tables) carries across units because we hold the same connection. results := make([]*unitResult, 0, len(units)) for _, u := range units { - r, err := runUnitBuffered(ctx, conn, u) + stmtCtx, stmtCancel := withStatementTimeout(signalCtx, f.timeout) + r, err := runUnitBuffered(stmtCtx, conn, u) + stmtCancel() if err != nil { // Render the successful prefix, then surface the error with // rich pgError formatting if applicable. @@ -218,7 +240,14 @@ func runQuery(ctx context.Context, cmd *cobra.Command, args []string, f queryFla // error to the user, the renderer error to debug logs. fmt.Fprintln(stderr, "warning: failed to render partial result:", rerr) } - return formatExecutionError(u.Source, err) + msg, invocationScoped := reportCancellation(signalCtx, stmtCtx, err, f.timeout) + if invocationScoped { + // User cancel / timeout is invocation-scoped; the source + // prefix is redundant ("--file foo.sql: Query cancelled." + // reads worse than just "Query cancelled."). + return errors.New(msg) + } + return errors.New(u.Source + ": " + msg) } results = append(results, r) } @@ -231,15 +260,51 @@ func runQuery(ctx context.Context, cmd *cobra.Command, args []string, f queryFla } } -// newSink returns the rowSink for the chosen output format. Kept separate -// from runQuery so tests can build sinks without going through pgx. -func newSink(format sqlcli.Format, out, stderr io.Writer) rowSink { +// withStatementTimeout returns ctx unchanged (and a no-op cancel) when +// timeout is zero, otherwise a child context with the timeout applied. Each +// statement gets its own deadline so cancellation is scoped to one +// statement at a time. +func withStatementTimeout(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { + if timeout <= 0 { + return parent, func() {} + } + return context.WithTimeout(parent, timeout) +} + +// reportCancellation distinguishes the three error cases that look the same +// from `executeOne`'s POV (a wrapped pgconn / network error): user cancelled +// via Ctrl+C, --timeout fired, or the statement just plain errored. Returns +// the human-readable message and whether the cause is invocation-scoped +// (cancel/timeout) rather than statement-scoped. +// +// Precedence: user cancel beats deadline. If both contexts fire near- +// simultaneously (race), we report "cancelled" because the user's intent +// dominates a coincidental timeout. +func reportCancellation(signalCtx, stmtCtx context.Context, err error, timeout time.Duration) (msg string, invocationScoped bool) { + switch { + case errors.Is(signalCtx.Err(), context.Canceled): + return "Query cancelled.", true + case timeout > 0 && errors.Is(stmtCtx.Err(), context.DeadlineExceeded): + return fmt.Sprintf("Query timed out after %s.", timeout), true + default: + return formatPgError(err), false + } +} + +// newSinkInteractive returns the rowSink for the chosen output format. When +// interactive is true the text sink may launch the libs/tableview viewer for +// results larger than staticTableThreshold; when false it uses the static +// tabwriter table. +func newSinkInteractive(format sqlcli.Format, out, stderr io.Writer, interactive bool) rowSink { switch format { case sqlcli.OutputJSON: return newJSONSink(out, stderr) case sqlcli.OutputCSV: return newCSVSink(out, stderr) default: + if interactive { + return newInteractiveTextSink(out) + } return newTextSink(out) } } @@ -258,13 +323,6 @@ func renderPartial(out, stderr io.Writer, format sqlcli.Format, results []*unitR } } -// formatExecutionError produces the error returned to cobra when an input -// unit failed. The message includes the source label so the user knows -// which of N inputs blew up. -func formatExecutionError(source string, err error) error { - return fmt.Errorf("%s: %s", source, formatPgError(err)) -} - // multiStatementHint is appended to errMultipleStatements so users see the // recovery path inline. const multiStatementHint = "\nThis command runs one statement per input. To run multiple statements:\n" + diff --git a/experimental/postgres/cmd/render.go b/experimental/postgres/cmd/render.go index a3c6aa5334..0d09556ece 100644 --- a/experimental/postgres/cmd/render.go +++ b/experimental/postgres/cmd/render.go @@ -6,25 +6,44 @@ import ( "strings" "text/tabwriter" + "github.com/databricks/cli/libs/tableview" "github.com/jackc/pgx/v5/pgconn" ) +// staticTableThreshold is the row count above which we hand off to +// libs/tableview's interactive viewer (when stdout is interactive). Smaller +// results stay in the static tabwriter path so they stream to a pipe +// unchanged. Matches the threshold aitools query uses. +const staticTableThreshold = 30 + // textSink renders results as plain text: a tabwriter-aligned table for // rows-producing statements, the command tag for command-only ones. // // Text output buffers all rows because tabwriter needs the widest cell in each // column before it can align. Streaming output is provided by the JSON and CSV // sinks; users with huge result sets should pick those. +// +// When interactive is true and the result has more than staticTableThreshold +// rows, End hands off to libs/tableview's scrollable viewer instead of +// emitting the static table. The interactive path requires a real TTY and a +// prompt-capable terminal; the caller decides. type textSink struct { - out io.Writer - columns []string - rows [][]string + out io.Writer + interactive bool + columns []string + rows [][]string } func newTextSink(out io.Writer) *textSink { return &textSink{out: out} } +// newInteractiveTextSink returns a text sink that uses the interactive table +// viewer for results larger than staticTableThreshold. +func newInteractiveTextSink(out io.Writer) *textSink { + return &textSink{out: out, interactive: true} +} + func (s *textSink) Begin(fields []pgconn.FieldDescription) error { s.columns = make([]string, len(fields)) for i, f := range fields { @@ -61,6 +80,16 @@ func (s *textSink) End(commandTag string) error { return err } + if s.interactive && len(s.rows) > staticTableThreshold { + // Try the interactive viewer; on failure (TUI startup, terminal + // resize race, etc.) fall through to the static path so the user + // still sees the rows their query returned. Without this fallback + // a successful query would surface as "viewer failed" with no data. + if err := tableview.Run(s.out, s.columns, s.rows); err == nil { + return nil + } + } + tw := tabwriter.NewWriter(s.out, 0, 0, 2, ' ', 0) fmt.Fprintln(tw, strings.Join(s.columns, "\t")) fmt.Fprintln(tw, strings.Join(headerSeparator(s.columns), "\t")) diff --git a/experimental/postgres/cmd/signals.go b/experimental/postgres/cmd/signals.go new file mode 100644 index 0000000000..5e4c29346f --- /dev/null +++ b/experimental/postgres/cmd/signals.go @@ -0,0 +1,40 @@ +package postgrescmd + +import ( + "context" + "os" + "os/signal" + "syscall" +) + +// watchInterruptSignals installs handlers for SIGINT and SIGTERM that call +// cancel when the user hits Ctrl+C or the process gets a SIGTERM. +// +// Returns a stop-and-cancel function that uninstalls the handlers (signal.Stop +// prevents future OS deliveries) and cancels the parent context so the +// goroutine wakes promptly. The caller must defer it. The channel is +// 1-buffered and GC'd on return; no explicit drain is needed. +// +// On Windows, Go maps Ctrl+C to os.Interrupt via the console-control-handler, +// so the same code path covers Windows. +func watchInterruptSignals(ctx context.Context, cancel context.CancelFunc) func() { + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) + + done := make(chan struct{}) + go func() { + select { + case <-sigCh: + cancel() + case <-ctx.Done(): + } + close(done) + }() + + return func() { + signal.Stop(sigCh) + // Wake the goroutine in case neither sigCh nor ctx.Done has fired. + cancel() + <-done + } +} diff --git a/libs/tableview/tableview.go b/libs/tableview/tableview.go index 54266b72f2..df47a29f85 100644 --- a/libs/tableview/tableview.go +++ b/libs/tableview/tableview.go @@ -17,6 +17,7 @@ const ( footerHeight = 1 searchFooterHeight = 2 // headerLines is the number of non-data lines at the top (header + separator). + // These are rendered above the viewport so they stay visible while data scrolls. headerLines = 2 ) @@ -30,11 +31,14 @@ var ( // Run displays tabular data in an interactive browser. // Writes to w (typically stdout). Blocks until user quits. func Run(w io.Writer, columns []string, rows [][]string) error { - lines := renderTableLines(columns, rows) + all := renderTableLines(columns, rows) + header := all[:headerLines] + dataLines := all[headerLines:] m := model{ - lines: lines, - cursor: headerLines, // Start on first data row. + header: header, + lines: dataLines, + cursor: 0, } p := tea.NewProgram(m, tea.WithOutput(w)) @@ -144,20 +148,21 @@ func (m model) renderContent() string { type model struct { //nolint:recvcheck // value receivers for tea.Model interface, pointer for cursor mutation viewport viewport.Model - lines []string + header []string // sticky header lines (column names + separator) + lines []string // data rows only ready bool - cursor int // line index of the highlighted row + cursor int // index into lines (data rows) // Search state. searching bool searchInput string searchQuery string - matchLines []int + matchLines []int // indices into lines matchIdx int } func (m model) dataRowCount() int { - return max(len(m.lines)-headerLines, 0) + return len(m.lines) } func (m model) Init() tea.Cmd { @@ -171,14 +176,16 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if m.searching { fh = searchFooterHeight } + // Reserve room for the sticky header above the viewport. + height := msg.Height - fh - len(m.header) if !m.ready { - m.viewport = viewport.New(msg.Width, msg.Height-fh) + m.viewport = viewport.New(msg.Width, height) m.viewport.SetHorizontalStep(horizontalScrollStep) m.viewport.SetContent(m.renderContent()) m.ready = true } else { m.viewport.Width = msg.Width - m.viewport.Height = msg.Height - fh + m.viewport.Height = height } return m, nil @@ -232,7 +239,7 @@ func (m model) updateNormal(msg tea.KeyMsg) (tea.Model, tea.Cmd) { m.moveCursor(m.viewport.Height) return m, nil case "g": - m.cursor = headerLines + m.cursor = 0 m.viewport.SetContent(m.renderContent()) m.viewport.GotoTop() return m, nil @@ -252,7 +259,7 @@ func (m model) updateNormal(msg tea.KeyMsg) (tea.Model, tea.Cmd) { // moveCursor moves the cursor by delta lines, clamped to data rows. func (m *model) moveCursor(delta int) { m.cursor += delta - m.cursor = max(m.cursor, headerLines) + m.cursor = max(m.cursor, 0) m.cursor = min(m.cursor, len(m.lines)-1) m.viewport.SetContent(m.renderContent()) m.scrollToCursor() @@ -311,7 +318,8 @@ func (m model) View() string { } footer := m.renderFooter() - return m.viewport.View() + "\n" + footer + header := strings.Join(m.header, "\n") + return header + "\n" + m.viewport.View() + "\n" + footer } func (m model) renderFooter() string { diff --git a/libs/tableview/tableview_test.go b/libs/tableview/tableview_test.go index c761a9cf00..d1fd2b964c 100644 --- a/libs/tableview/tableview_test.go +++ b/libs/tableview/tableview_test.go @@ -1,8 +1,10 @@ package tableview import ( + "strings" "testing" + "github.com/charmbracelet/bubbles/viewport" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -70,3 +72,48 @@ func TestHighlightSearchNoMatch(t *testing.T) { result := highlightSearch("hello bob", "alice") assert.Equal(t, "hello bob", result) } + +// readyModel constructs a model in the same shape Run produces, plus a viewport +// large enough that the cursor visibility logic does not need to scroll. +func readyModel(columns []string, rows [][]string, viewportHeight int) model { + all := renderTableLines(columns, rows) + m := model{ + header: all[:headerLines], + lines: all[headerLines:], + } + m.viewport = viewport.New(80, viewportHeight) + m.viewport.SetContent(m.renderContent()) + m.ready = true + return m +} + +func TestViewKeepsHeaderAboveScrollableContent(t *testing.T) { + columns := []string{"id", "name"} + rows := [][]string{{"1", "alice"}, {"2", "bob"}, {"3", "carol"}} + m := readyModel(columns, rows, 2) + + // Scroll the viewport down so the first data row falls below the top + // of the viewport. Before the sticky-header change this would also push + // the column header off-screen and never bring it back. + m.viewport.SetYOffset(1) + + out := m.View() + headerIdx := strings.Index(out, "id") + carolIdx := strings.Index(out, "carol") + require.NotEqual(t, -1, headerIdx, "View output must contain the column header") + require.NotEqual(t, -1, carolIdx, "View output must contain the visible row after scrolling") + assert.Less(t, headerIdx, carolIdx, "column header must render above the scrolled rows") +} + +func TestModelDataRowCountExcludesHeader(t *testing.T) { + m := readyModel([]string{"id"}, [][]string{{"1"}, {"2"}, {"3"}}, 5) + assert.Equal(t, 3, m.dataRowCount()) +} + +func TestMoveCursorClampsAtZeroAndLast(t *testing.T) { + m := readyModel([]string{"id"}, [][]string{{"1"}, {"2"}, {"3"}}, 5) + m.moveCursor(-100) + assert.Equal(t, 0, m.cursor, "cursor should clamp to first data row, not below") + m.moveCursor(100) + assert.Equal(t, 2, m.cursor, "cursor should clamp to last data row") +}