diff --git a/src/client/streamableHttp.ts b/src/client/streamableHttp.ts index 736587973..728d1237f 100644 --- a/src/client/streamableHttp.ts +++ b/src/client/streamableHttp.ts @@ -137,6 +137,7 @@ export class StreamableHTTPClientTransport implements Transport { private _lastUpscopingHeader?: string; // Track last upscoping header to prevent infinite upscoping. private _serverRetryMs?: number; // Server-provided retry delay from SSE retry field private _reconnectionTimeout?: ReturnType; + private _sseStreamOpened = false; // Track if SSE stream was successfully opened onclose?: () => void; onerror?: (error: Error) => void; @@ -240,6 +241,7 @@ export class StreamableHTTPClientTransport implements Transport { throw new StreamableHTTPError(response.status, `Failed to open SSE stream: ${response.statusText}`); } + this._sseStreamOpened = true; this._handleSseStream(response.body, options, true); } catch (error) { this.onerror?.(error as Error); @@ -479,10 +481,19 @@ export class StreamableHTTPClientTransport implements Transport { // Handle session ID received during initialization const sessionId = response.headers.get('mcp-session-id'); + const hadSessionId = this._sessionId !== undefined; if (sessionId) { this._sessionId = sessionId; } + // If we just received a session ID for the first time and SSE stream is not open, + // try to open it now. This handles the case where the initial SSE connection + // during start() was rejected because the server wasn't initialized yet. + // See: https://github.com/modelcontextprotocol/typescript-sdk/issues/1167 + if (sessionId && !hadSessionId && !this._sseStreamOpened) { + this._startOrAuthSse({ resumptionToken: undefined }).catch(err => this.onerror?.(err)); + } + if (!response.ok) { const text = await response.text().catch(() => null); diff --git a/test/client/streamableHttp.test.ts b/test/client/streamableHttp.test.ts index 52c8f1074..8682098b0 100644 --- a/test/client/streamableHttp.test.ts +++ b/test/client/streamableHttp.test.ts @@ -101,8 +101,19 @@ describe('StreamableHTTPClientTransport', () => { headers: new Headers({ 'content-type': 'text/event-stream', 'mcp-session-id': 'test-session-id' }) }); + // Mock the SSE stream GET request that happens after receiving session ID + (global.fetch as Mock).mockResolvedValueOnce({ + ok: false, + status: 405, + headers: new Headers(), + body: { cancel: vi.fn() } + }); + await transport.send(message); + // Allow the async SSE connection attempt to complete + await new Promise(resolve => setTimeout(resolve, 10)); + // Send a second message that should include the session ID (global.fetch as Mock).mockResolvedValueOnce({ ok: true, @@ -137,7 +148,19 @@ describe('StreamableHTTPClientTransport', () => { headers: new Headers({ 'content-type': 'text/event-stream', 'mcp-session-id': 'test-session-id' }) }); + // Mock the SSE stream GET request that happens after receiving session ID + (global.fetch as Mock).mockResolvedValueOnce({ + ok: false, + status: 405, + headers: new Headers(), + body: { cancel: vi.fn() } + }); + await transport.send(message); + + // Allow the async SSE connection attempt to complete + await new Promise(resolve => setTimeout(resolve, 10)); + expect(transport.sessionId).toBe('test-session-id'); // Now terminate the session @@ -177,8 +200,19 @@ describe('StreamableHTTPClientTransport', () => { headers: new Headers({ 'content-type': 'text/event-stream', 'mcp-session-id': 'test-session-id' }) }); + // Mock the SSE stream GET request that happens after receiving session ID + (global.fetch as Mock).mockResolvedValueOnce({ + ok: false, + status: 405, + headers: new Headers(), + body: { cancel: vi.fn() } + }); + await transport.send(message); + // Allow the async SSE connection attempt to complete + await new Promise(resolve => setTimeout(resolve, 10)); + // Now terminate the session, but server responds with 405 (global.fetch as Mock).mockResolvedValueOnce({ ok: false, diff --git a/test/integration-tests/stateManagementStreamableHttp.test.ts b/test/integration-tests/stateManagementStreamableHttp.test.ts index 672bfb92f..288a03670 100644 --- a/test/integration-tests/stateManagementStreamableHttp.test.ts +++ b/test/integration-tests/stateManagementStreamableHttp.test.ts @@ -9,6 +9,7 @@ import { ListToolsResultSchema, ListResourcesResultSchema, ListPromptsResultSchema, + ListRootsRequestSchema, LATEST_PROTOCOL_VERSION } from '../../src/types.js'; import { zodTestMatrix, type ZodMatrixEntry } from '../../src/__fixtures__/zodTestMatrix.js'; @@ -376,6 +377,52 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { // Clean up await transport.close(); }); + + it('should support server-initiated roots/list request', async () => { + // This test reproduces GitHub issue #1167 + // https://github.com/modelcontextprotocol/typescript-sdk/issues/1167 + // + // The bug: server.listRoots() hangs when using HTTP transport because: + // 1. Client tries to open GET SSE stream before initialization + // 2. Server rejects with 400 "Server not initialized" + // 3. Client never retries opening SSE stream after initialization + // 4. Server's send() silently returns when no SSE stream exists + // 5. listRoots() promise never resolves + + // Create client with roots capability + const client = new Client({ name: 'test-client', version: '1.0.0' }, { capabilities: { roots: { listChanged: true } } }); + + // Register handler for roots/list requests from server + client.setRequestHandler(ListRootsRequestSchema, async () => { + return { + roots: [{ uri: 'file:///home/user/project', name: 'Test Project' }] + }; + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Verify client has session ID (stateful mode) + expect(transport.sessionId).toBeDefined(); + + // Now try to call listRoots from the server + const rootsPromise = mcpServer!.server.listRoots(); + + // Use a short timeout to detect the hang + const timeoutPromise = new Promise((_, reject) => { + setTimeout(() => reject(new Error('listRoots() timed out - SSE stream not working')), 2000); + }); + + const result = await Promise.race([rootsPromise, timeoutPromise]); + + expect(result.roots).toHaveLength(1); + expect(result.roots[0]).toEqual({ + uri: 'file:///home/user/project', + name: 'Test Project' + }); + + await transport.close(); + }); }); }); }); diff --git a/test/server/index.test.ts b/test/server/index.test.ts index e434e57fc..63568a1a6 100644 --- a/test/server/index.test.ts +++ b/test/server/index.test.ts @@ -13,12 +13,14 @@ import { LATEST_PROTOCOL_VERSION, ListPromptsRequestSchema, ListResourcesRequestSchema, + ListRootsRequestSchema, ListToolsRequestSchema, type LoggingMessageNotification, McpError, NotificationSchema, RequestSchema, ResultSchema, + RootsListChangedNotificationSchema, SetLevelRequestSchema, SUPPORTED_PROTOCOL_VERSIONS, CreateTaskResultSchema @@ -3277,3 +3279,328 @@ test('should respect client task capabilities', async () => { clientTaskStore.cleanup(); }); + +describe('roots/list', () => { + test('should successfully list roots when client supports roots capability', async () => { + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {} + }, + enforceStrictCapabilities: true + } + ); + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + roots: {} + } + } + ); + + // Register handler for roots/list + client.setRequestHandler(ListRootsRequestSchema, async () => { + return { + roots: [ + { + uri: 'file:///home/user/project', + name: 'My Project' + } + ] + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + const result = await server.listRoots(); + + expect(result.roots).toHaveLength(1); + expect(result.roots[0]).toEqual({ + uri: 'file:///home/user/project', + name: 'My Project' + }); + }); + + test('should handle empty roots list', async () => { + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {} + }, + enforceStrictCapabilities: true + } + ); + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + roots: {} + } + } + ); + + // Return empty roots list + client.setRequestHandler(ListRootsRequestSchema, async () => { + return { + roots: [] + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + const result = await server.listRoots(); + + expect(result.roots).toHaveLength(0); + expect(result.roots).toEqual([]); + }); + + test('should handle multiple roots', async () => { + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {} + }, + enforceStrictCapabilities: true + } + ); + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + roots: {} + } + } + ); + + const expectedRoots = [ + { uri: 'file:///home/user/project1', name: 'Project 1' }, + { uri: 'file:///home/user/project2', name: 'Project 2' }, + { uri: 'file:///var/data/shared' } + ]; + + client.setRequestHandler(ListRootsRequestSchema, async () => { + return { roots: expectedRoots }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + const result = await server.listRoots(); + + expect(result.roots).toHaveLength(3); + expect(result.roots).toEqual(expectedRoots); + }); + + test('should handle roots with optional name and _meta fields', async () => { + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {} + }, + enforceStrictCapabilities: true + } + ); + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + roots: {} + } + } + ); + + const expectedRoots = [ + // Root with all optional fields + { + uri: 'file:///home/user/project', + name: 'Full Project', + _meta: { + type: 'workspace', + priority: 1 + } + }, + // Root with only uri (minimal) + { + uri: 'file:///tmp/scratch' + }, + // Root with name but no _meta + { + uri: 'file:///var/logs', + name: 'Log Directory' + } + ]; + + client.setRequestHandler(ListRootsRequestSchema, async () => { + return { roots: expectedRoots }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + const result = await server.listRoots(); + + expect(result.roots).toHaveLength(3); + expect(result.roots[0]).toEqual({ + uri: 'file:///home/user/project', + name: 'Full Project', + _meta: { + type: 'workspace', + priority: 1 + } + }); + expect(result.roots[1]).toEqual({ + uri: 'file:///tmp/scratch' + }); + expect(result.roots[2]).toEqual({ + uri: 'file:///var/logs', + name: 'Log Directory' + }); + }); + + test('should send roots list changed notification', async () => { + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {} + }, + enforceStrictCapabilities: true + } + ); + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + roots: { + listChanged: true + } + } + } + ); + + // Track if notification was received + let notificationReceived = false; + + server.setNotificationHandler(RootsListChangedNotificationSchema, async () => { + notificationReceived = true; + }); + + client.setRequestHandler(ListRootsRequestSchema, async () => { + return { roots: [] }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Send the notification + await client.sendRootsListChanged(); + + // Give a moment for the notification to be processed + await new Promise(resolve => setTimeout(resolve, 10)); + + expect(notificationReceived).toBe(true); + }); + + test('should pass context to roots/list handler', async () => { + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {} + }, + enforceStrictCapabilities: true + } + ); + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + roots: {} + } + } + ); + + let capturedExtra: unknown = null; + + client.setRequestHandler(ListRootsRequestSchema, async (_request, extra) => { + capturedExtra = extra; + return { roots: [] }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + await server.listRoots(); + + expect(capturedExtra).not.toBeNull(); + expect(capturedExtra).toHaveProperty('sessionId'); + expect(capturedExtra).toHaveProperty('signal'); + }); +});