diff --git a/go/client.go b/go/client.go index a37040f2..f588bf8f 100644 --- a/go/client.go +++ b/go/client.go @@ -434,12 +434,12 @@ func (c *Client) ForceStop() { c.RPC = nil } -func (c *Client) ensureConnected() error { +func (c *Client) ensureConnected(ctx context.Context) error { if c.client != nil { return nil } if c.autoStart { - return c.Start(context.Background()) + return c.Start(ctx) } return fmt.Errorf("client not connected. Call Start() first") } @@ -478,7 +478,7 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses return nil, fmt.Errorf("an OnPermissionRequest handler is required when creating a session. For example, to allow all permissions, use &copilot.SessionConfig{OnPermissionRequest: copilot.PermissionHandler.ApproveAll}") } - if err := c.ensureConnected(); err != nil { + if err := c.ensureConnected(ctx); err != nil { return nil, err } @@ -575,7 +575,7 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, return nil, fmt.Errorf("an OnPermissionRequest handler is required when resuming a session. For example, to allow all permissions, use &copilot.ResumeSessionConfig{OnPermissionRequest: copilot.PermissionHandler.ApproveAll}") } - if err := c.ensureConnected(); err != nil { + if err := c.ensureConnected(ctx); err != nil { return nil, err } @@ -664,7 +664,7 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, // // sessions, err := client.ListSessions(context.Background(), &SessionListFilter{Repository: "owner/repo"}) func (c *Client) ListSessions(ctx context.Context, filter *SessionListFilter) ([]SessionMetadata, error) { - if err := c.ensureConnected(); err != nil { + if err := c.ensureConnected(ctx); err != nil { return nil, err } @@ -696,7 +696,7 @@ func (c *Client) ListSessions(ctx context.Context, filter *SessionListFilter) ([ // log.Fatal(err) // } func (c *Client) DeleteSession(ctx context.Context, sessionID string) error { - if err := c.ensureConnected(); err != nil { + if err := c.ensureConnected(ctx); err != nil { return err } @@ -743,7 +743,7 @@ func (c *Client) DeleteSession(ctx context.Context, sessionID string) error { // }) // } func (c *Client) GetLastSessionID(ctx context.Context) (*string, error) { - if err := c.ensureConnected(); err != nil { + if err := c.ensureConnected(ctx); err != nil { return nil, err } @@ -775,14 +775,8 @@ func (c *Client) GetLastSessionID(ctx context.Context) (*string, error) { // fmt.Printf("TUI is displaying session: %s\n", *sessionID) // } func (c *Client) GetForegroundSessionID(ctx context.Context) (*string, error) { - if c.client == nil { - if c.autoStart { - if err := c.Start(ctx); err != nil { - return nil, err - } - } else { - return nil, fmt.Errorf("client not connected. Call Start() first") - } + if err := c.ensureConnected(ctx); err != nil { + return nil, err } result, err := c.client.Request("session.getForeground", getForegroundSessionRequest{}) @@ -809,14 +803,8 @@ func (c *Client) GetForegroundSessionID(ctx context.Context) (*string, error) { // log.Fatal(err) // } func (c *Client) SetForegroundSessionID(ctx context.Context, sessionID string) error { - if c.client == nil { - if c.autoStart { - if err := c.Start(ctx); err != nil { - return err - } - } else { - return fmt.Errorf("client not connected. Call Start() first") - } + if err := c.ensureConnected(ctx); err != nil { + return err } result, err := c.client.Request("session.setForeground", setForegroundSessionRequest{SessionID: sessionID}) @@ -1123,7 +1111,7 @@ func (c *Client) startCLIServer(ctx context.Context) error { args = append([]string{cliPath}, args...) } - c.process = exec.CommandContext(ctx, command, args...) + c.process = exec.Command(command, args...) // Configure platform-specific process attributes (e.g., hide window on Windows) configureProcAttr(c.process) @@ -1179,14 +1167,16 @@ func (c *Client) startCLIServer(ctx context.Context) error { c.monitorProcess() scanner := bufio.NewScanner(stdout) - timeout := time.After(10 * time.Second) portRegex := regexp.MustCompile(`listening on port (\d+)`) + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + for { select { - case <-timeout: + case <-ctx.Done(): killErr := c.killProcess() - return errors.Join(errors.New("timeout waiting for CLI server to start"), killErr) + return errors.Join(fmt.Errorf("failed waiting for CLI server to start: %w", ctx.Err()), killErr) case <-c.processDone: killErr := c.killProcess() return errors.Join(errors.New("CLI server process exited before reporting port"), killErr) @@ -1258,12 +1248,13 @@ func (c *Client) connectViaTcp(ctx context.Context) error { return fmt.Errorf("server port not available") } - // Create TCP connection that cancels on context done or after 10 seconds + // Merge a 10-second timeout with the caller's context so whichever + // deadline comes first wins. address := net.JoinHostPort(c.actualHost, fmt.Sprintf("%d", c.actualPort)) - dialer := net.Dialer{ - Timeout: 10 * time.Second, - } - conn, err := dialer.DialContext(ctx, "tcp", address) + dialCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + var dialer net.Dialer + conn, err := dialer.DialContext(dialCtx, "tcp", address) if err != nil { return fmt.Errorf("failed to connect to CLI server at %s: %w", address, err) } diff --git a/go/client_test.go b/go/client_test.go index d791a5a3..d707418b 100644 --- a/go/client_test.go +++ b/go/client_test.go @@ -1,6 +1,7 @@ package copilot import ( + "context" "encoding/json" "os" "path/filepath" @@ -529,6 +530,32 @@ func TestClient_ResumeSession_RequiresPermissionHandler(t *testing.T) { }) } +func TestClient_StartContextCancellationDoesNotKillProcess(t *testing.T) { + cliPath := findCLIPathForTest() + if cliPath == "" { + t.Skip("CLI not found") + } + + client := NewClient(&ClientOptions{CLIPath: cliPath}) + t.Cleanup(func() { client.ForceStop() }) + + // Start with a context, then cancel it after the client is connected. + ctx, cancel := context.WithCancel(t.Context()) + if err := client.Start(ctx); err != nil { + t.Fatalf("Start failed: %v", err) + } + cancel() // cancel the context that was used for Start + + // The CLI process should still be alive and responsive. + resp, err := client.Ping(t.Context(), "still alive") + if err != nil { + t.Fatalf("Ping after context cancellation failed: %v", err) + } + if resp == nil { + t.Fatal("expected non-nil ping response") + } +} + func TestClient_StartStopRace(t *testing.T) { cliPath := findCLIPathForTest() if cliPath == "" {