diff --git a/mcp/streamable.go b/mcp/streamable.go index 76f2b0e4..f4dd21f3 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1801,6 +1801,22 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e if (resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden) && c.oauthHandler != nil { if err := c.oauthHandler.Authorize(ctx, req, resp); err != nil { + // If the caller's context was cancelled while we were running the + // authorization flow, treat the connection as failed so subsequent + // operations on it (e.g. the cancellation notify the call layer + // sends in response to ctx cancellation) short-circuit instead of + // re-invoking the OAuth handler. Otherwise the user gets prompted + // to authorize a request they have already abandoned. See #882. + // + // We check ctx.Err() rather than the error returned by Authorize, + // because the handler is user-implemented and may return an error + // that does not wrap context.Canceled (e.g. a custom sentinel or + // a fmt.Errorf with %v). The context itself is the authoritative + // source for whether the caller abandoned the request. + ctxErr := ctx.Err() + if errors.Is(ctxErr, context.Canceled) || errors.Is(ctxErr, context.DeadlineExceeded) { + c.fail(fmt.Errorf("%s: authorization cancelled: %w", requestSummary, err)) + } // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr // and permanently break the connection. // Wrap the authorization error as well for client inspection. diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go index 43564fd3..719cb580 100644 --- a/mcp/streamable_client_test.go +++ b/mcp/streamable_client_test.go @@ -1016,6 +1016,94 @@ func TestStreamableClientOAuth_401(t *testing.T) { } } +// blockingCountingOAuthHandler is an OAuthHandler that blocks inside +// Authorize until the caller's context is cancelled, then returns a custom +// error that does NOT wrap context.Canceled. This mirrors real-world OAuth +// handlers that catch the cancellation internally and surface their own +// error type. The fix for #882 checks ctx.Err() directly rather than +// relying on the error from Authorize, so this must still trigger c.fail(). +// It records how many times Authorize is invoked. +type blockingCountingOAuthHandler struct { + mu sync.Mutex + callCount int +} + +func (h *blockingCountingOAuthHandler) TokenSource(ctx context.Context) (oauth2.TokenSource, error) { + return nil, nil +} + +func (h *blockingCountingOAuthHandler) Authorize(ctx context.Context, req *http.Request, resp *http.Response) error { + h.mu.Lock() + h.callCount++ + h.mu.Unlock() + // Block until the caller's context is cancelled, mirroring an + // interactive OAuth flow that the user has abandoned. + <-ctx.Done() + // Return a custom error that does not wrap context.Canceled, as a + // real-world handler might. The code under test must check ctx.Err() + // to detect the cancellation, not this error. + return fmt.Errorf("oauth flow interrupted") +} + +func (h *blockingCountingOAuthHandler) Calls() int { + h.mu.Lock() + defer h.mu.Unlock() + return h.callCount +} + +// TestStreamableClientOAuth_CancelledAuthorize_NoReprompt is a regression +// test for #882. When OAuthHandler.Authorize returns a context-cancelled +// error, the connection must enter a failed state so that the cancellation +// notification the call layer sends in response to ctx cancellation does +// not flow back through the same broken auth path and re-invoke Authorize. +func TestStreamableClientOAuth_CancelledAuthorize_NoReprompt(t *testing.T) { + handler := &blockingCountingOAuthHandler{} + + fake := &fakeStreamableServer{ + t: t, + responses: fakeResponses{ + {"POST", "", methodInitialize, ""}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "123", + }, + body: jsonBody(t, initResp), + }, + }, + } + verifier := func(ctx context.Context, token string, req *http.Request) (*auth.TokenInfo, error) { + return &auth.TokenInfo{Expiration: time.Now().Add(time.Hour)}, nil + } + httpServer := httptest.NewServer(auth.RequireBearerToken(verifier, nil)(fake)) + t.Cleanup(httpServer.Close) + + transport := &StreamableClientTransport{ + Endpoint: httpServer.URL, + OAuthHandler: handler, + } + client := NewClient(testImpl, nil) + + // Use a context with a tight deadline so the cancellation path runs + // while the auth flow is in progress. + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + + _, err := client.Connect(ctx, transport, nil) + if err == nil { + t.Fatal("expected client.Connect to fail") + } + + // Give the cancellation Notify path a moment to (try to) run. + time.Sleep(50 * time.Millisecond) + + // Authorize should be invoked exactly once. The bug in #882 caused + // it to be invoked a second time when the call layer sent the + // cancellation notification through the same auth-broken connection. + if got := handler.Calls(); got != 1 { + t.Errorf("expected Authorize to be called exactly 1 time, got %d", got) + } +} + func TestTokenInfo(t *testing.T) { ctx := context.Background() diff --git a/mcp/transport.go b/mcp/transport.go index 23dccf8e..426aa07e 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -14,13 +14,20 @@ import ( "net" "os" "sync" + "time" internaljson "github.com/modelcontextprotocol/go-sdk/internal/json" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" - "github.com/modelcontextprotocol/go-sdk/internal/xcontext" "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) +// notifyCancellationTimeout bounds the cancellation notification we send to +// the peer when the caller's context is cancelled. The notification is +// best-effort: a degraded connection (e.g. an OAuth flow that has been +// abandoned) must not be able to block the caller's return path or +// re-trigger expensive recovery on its behalf. See issue #882. +const notifyCancellationTimeout = 5 * time.Second + // ErrConnectionClosed is returned when sending a message to a connection that // is closed or in the process of closing. var ErrConnectionClosed = errors.New("connection closed") @@ -216,8 +223,19 @@ func call(ctx context.Context, conn *jsonrpc2.Connection, method string, params case errors.Is(err, jsonrpc2.ErrClientClosing), errors.Is(err, jsonrpc2.ErrServerClosing): return fmt.Errorf("%w: calling %q: %v", ErrConnectionClosed, method, err) case ctx.Err() != nil: - // Notify the peer of cancellation. - err := conn.Notify(xcontext.Detach(ctx), notificationCancelled, &CancelledParams{ + // Best-effort notify the peer of cancellation. We deliberately bound + // this with a fresh, short-lived context derived from + // context.Background() rather than reusing (or merely detaching) the + // caller's already-cancelled context. The connection may be in a + // degraded state — for example, the original failure may have come + // from an OAuth flow whose handler context expired (see #882) — and + // reusing that context would either return immediately with an error + // or, worse, re-trigger expensive recovery (re-auth) on the caller's + // return path. The bounded background context lets the notification + // attempt to deliver but never blocks the caller indefinitely. + notifyCtx, cancelNotify := context.WithTimeout(context.Background(), notifyCancellationTimeout) + defer cancelNotify() + err := conn.Notify(notifyCtx, notificationCancelled, &CancelledParams{ Reason: ctx.Err().Error(), RequestID: call.ID().Raw(), })