Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 23 additions & 32 deletions go/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -434,12 +434,12 @@ func (c *Client) ForceStop() {
c.RPC = nil
}

func (c *Client) ensureConnected() error {
func (c *Client) ensureConnected(ctx context.Context) error {
if c.client != nil {
return nil
}
if c.autoStart {
return c.Start(context.Background())
return c.Start(ctx)
}
return fmt.Errorf("client not connected. Call Start() first")
}
Expand Down Expand Up @@ -478,7 +478,7 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses
return nil, fmt.Errorf("an OnPermissionRequest handler is required when creating a session. For example, to allow all permissions, use &copilot.SessionConfig{OnPermissionRequest: copilot.PermissionHandler.ApproveAll}")
}

if err := c.ensureConnected(); err != nil {
if err := c.ensureConnected(ctx); err != nil {
return nil, err
}

Expand Down Expand Up @@ -575,7 +575,7 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string,
return nil, fmt.Errorf("an OnPermissionRequest handler is required when resuming a session. For example, to allow all permissions, use &copilot.ResumeSessionConfig{OnPermissionRequest: copilot.PermissionHandler.ApproveAll}")
}

if err := c.ensureConnected(); err != nil {
if err := c.ensureConnected(ctx); err != nil {
return nil, err
}

Expand Down Expand Up @@ -664,7 +664,7 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string,
//
// sessions, err := client.ListSessions(context.Background(), &SessionListFilter{Repository: "owner/repo"})
func (c *Client) ListSessions(ctx context.Context, filter *SessionListFilter) ([]SessionMetadata, error) {
if err := c.ensureConnected(); err != nil {
if err := c.ensureConnected(ctx); err != nil {
return nil, err
}

Expand Down Expand Up @@ -696,7 +696,7 @@ func (c *Client) ListSessions(ctx context.Context, filter *SessionListFilter) ([
// log.Fatal(err)
// }
func (c *Client) DeleteSession(ctx context.Context, sessionID string) error {
if err := c.ensureConnected(); err != nil {
if err := c.ensureConnected(ctx); err != nil {
return err
}

Expand Down Expand Up @@ -743,7 +743,7 @@ func (c *Client) DeleteSession(ctx context.Context, sessionID string) error {
// })
// }
func (c *Client) GetLastSessionID(ctx context.Context) (*string, error) {
if err := c.ensureConnected(); err != nil {
if err := c.ensureConnected(ctx); err != nil {
return nil, err
}

Expand Down Expand Up @@ -775,14 +775,8 @@ func (c *Client) GetLastSessionID(ctx context.Context) (*string, error) {
// fmt.Printf("TUI is displaying session: %s\n", *sessionID)
// }
func (c *Client) GetForegroundSessionID(ctx context.Context) (*string, error) {
if c.client == nil {
if c.autoStart {
if err := c.Start(ctx); err != nil {
return nil, err
}
} else {
return nil, fmt.Errorf("client not connected. Call Start() first")
}
if err := c.ensureConnected(ctx); err != nil {
return nil, err
}

result, err := c.client.Request("session.getForeground", getForegroundSessionRequest{})
Expand All @@ -809,14 +803,8 @@ func (c *Client) GetForegroundSessionID(ctx context.Context) (*string, error) {
// log.Fatal(err)
// }
func (c *Client) SetForegroundSessionID(ctx context.Context, sessionID string) error {
if c.client == nil {
if c.autoStart {
if err := c.Start(ctx); err != nil {
return err
}
} else {
return fmt.Errorf("client not connected. Call Start() first")
}
if err := c.ensureConnected(ctx); err != nil {
return err
}

result, err := c.client.Request("session.setForeground", setForegroundSessionRequest{SessionID: sessionID})
Expand Down Expand Up @@ -1123,7 +1111,7 @@ func (c *Client) startCLIServer(ctx context.Context) error {
args = append([]string{cliPath}, args...)
}

c.process = exec.CommandContext(ctx, command, args...)
c.process = exec.Command(command, args...)

// Configure platform-specific process attributes (e.g., hide window on Windows)
configureProcAttr(c.process)
Expand Down Expand Up @@ -1179,14 +1167,16 @@ func (c *Client) startCLIServer(ctx context.Context) error {
c.monitorProcess()

scanner := bufio.NewScanner(stdout)
timeout := time.After(10 * time.Second)
portRegex := regexp.MustCompile(`listening on port (\d+)`)

ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()

for {
select {
case <-timeout:
case <-ctx.Done():
killErr := c.killProcess()
return errors.Join(errors.New("timeout waiting for CLI server to start"), killErr)
return errors.Join(fmt.Errorf("failed waiting for CLI server to start: %w", ctx.Err()), killErr)
case <-c.processDone:
killErr := c.killProcess()
return errors.Join(errors.New("CLI server process exited before reporting port"), killErr)
Expand Down Expand Up @@ -1258,12 +1248,13 @@ func (c *Client) connectViaTcp(ctx context.Context) error {
return fmt.Errorf("server port not available")
}

// Create TCP connection that cancels on context done or after 10 seconds
// Merge a 10-second timeout with the caller's context so whichever
// deadline comes first wins.
address := net.JoinHostPort(c.actualHost, fmt.Sprintf("%d", c.actualPort))
dialer := net.Dialer{
Timeout: 10 * time.Second,
}
conn, err := dialer.DialContext(ctx, "tcp", address)
dialCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
var dialer net.Dialer
conn, err := dialer.DialContext(dialCtx, "tcp", address)
if err != nil {
return fmt.Errorf("failed to connect to CLI server at %s: %w", address, err)
}
Expand Down
27 changes: 27 additions & 0 deletions go/client_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package copilot

import (
"context"
"encoding/json"
"os"
"path/filepath"
Expand Down Expand Up @@ -529,6 +530,32 @@ func TestClient_ResumeSession_RequiresPermissionHandler(t *testing.T) {
})
}

func TestClient_StartContextCancellationDoesNotKillProcess(t *testing.T) {
cliPath := findCLIPathForTest()
if cliPath == "" {
t.Skip("CLI not found")
}

client := NewClient(&ClientOptions{CLIPath: cliPath})
t.Cleanup(func() { client.ForceStop() })

// Start with a context, then cancel it after the client is connected.
ctx, cancel := context.WithCancel(t.Context())
if err := client.Start(ctx); err != nil {
t.Fatalf("Start failed: %v", err)
}
cancel() // cancel the context that was used for Start

// The CLI process should still be alive and responsive.
resp, err := client.Ping(t.Context(), "still alive")
if err != nil {
t.Fatalf("Ping after context cancellation failed: %v", err)
}
if resp == nil {
t.Fatal("expected non-nil ping response")
}
}

func TestClient_StartStopRace(t *testing.T) {
cliPath := findCLIPathForTest()
if cliPath == "" {
Expand Down
Loading