diff --git a/src/ModelContextProtocol.Core/Client/StdioClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/StdioClientSessionTransport.cs index a92093246..907f9e062 100644 --- a/src/ModelContextProtocol.Core/Client/StdioClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StdioClientSessionTransport.cs @@ -10,15 +10,17 @@ internal sealed class StdioClientSessionTransport : StreamClientSessionTransport private readonly StdioClientTransportOptions _options; private readonly Process _process; private readonly Queue _stderrRollingLog; + private readonly DataReceivedEventHandler _errorHandler; private int _cleanedUp = 0; private readonly int? _processId; - public StdioClientSessionTransport(StdioClientTransportOptions options, Process process, string endpointName, Queue stderrRollingLog, ILoggerFactory? loggerFactory) : + public StdioClientSessionTransport(StdioClientTransportOptions options, Process process, string endpointName, Queue stderrRollingLog, DataReceivedEventHandler errorHandler, ILoggerFactory? loggerFactory) : base(process.StandardInput.BaseStream, process.StandardOutput.BaseStream, encoding: null, endpointName, loggerFactory) { _options = options; _process = process; _stderrRollingLog = stderrRollingLog; + _errorHandler = errorHandler; try { _processId = process.Id; } catch { } } @@ -45,9 +47,13 @@ public override async Task SendMessageAsync(JsonRpcMessage message, Cancellation /// protected override async ValueTask CleanupAsync(Exception? error = null, CancellationToken cancellationToken = default) { - // Only clean up once. + // Only run the full stdio cleanup once (handler detach, process kill, etc.). + // If another call is already handling cleanup, cancel the shutdown token + // to unblock it (e.g. if it's stuck in WaitForExitAsync) and let it + // call SetDisconnected with full StdioClientCompletionDetails. if (Interlocked.Exchange(ref _cleanedUp, 1) != 0) { + CancelShutdown(); return; } @@ -55,6 +61,10 @@ protected override async ValueTask CleanupAsync(Exception? error = null, Cancell // so create an exception with details about that. error ??= await GetUnexpectedExitExceptionAsync(cancellationToken).ConfigureAwait(false); + // Detach the stderr handler so no further ErrorDataReceived events + // are dispatched during or after process disposal. + _process.ErrorDataReceived -= _errorHandler; + // Terminate the server process (or confirm it already exited), then build // and publish strongly-typed completion details while the process handle // is still valid so we can read the exit code. @@ -89,13 +99,17 @@ protected override async ValueTask CleanupAsync(Exception? error = null, Cancell try { // The process has exited, but we still need to ensure stderr has been flushed. - // WaitForExitAsync only waits for exit; it does not guarantee that all - // ErrorDataReceived events have been dispatched. The synchronous WaitForExit() - // (no arguments) does ensure that, so call it after WaitForExitAsync completes. + // Use a bounded wait: the process is already dead, we're just draining pipe + // buffers. If the caller's token is never canceled (e.g. _shutdownCts hasn't + // been canceled yet), an unbounded wait here can hang indefinitely when the + // threadpool is slow to deliver the stderr EOF callback. #if NET - await _process.WaitForExitAsync(cancellationToken).ConfigureAwait(false); + using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + timeoutCts.CancelAfter(_options.ShutdownTimeout); + await _process.WaitForExitAsync(timeoutCts.Token).ConfigureAwait(false); +#else + _process.WaitForExit((int)_options.ShutdownTimeout.TotalMilliseconds); #endif - _process.WaitForExit(); } catch { } diff --git a/src/ModelContextProtocol.Core/Client/StdioClientTransport.cs b/src/ModelContextProtocol.Core/Client/StdioClientTransport.cs index 24a47613b..24a7dba8e 100644 --- a/src/ModelContextProtocol.Core/Client/StdioClientTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StdioClientTransport.cs @@ -59,6 +59,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken = Process? process = null; bool processStarted = false; + DataReceivedEventHandler? errorHandler = null; string command = _options.Command; IList? arguments = _options.Arguments; @@ -136,7 +137,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken = // few lines in a rolling log for use in exceptions. const int MaxStderrLength = 10; // keep the last 10 lines of stderr Queue stderrRollingLog = new(MaxStderrLength); - process.ErrorDataReceived += (sender, args) => + errorHandler = (sender, args) => { string? data = args.Data; if (data is not null) @@ -151,11 +152,22 @@ public async Task ConnectAsync(CancellationToken cancellationToken = stderrRollingLog.Enqueue(data); } - _options.StandardErrorLines?.Invoke(data); + try + { + _options.StandardErrorLines?.Invoke(data); + } + catch (Exception ex) + { + // Prevent exceptions in the user callback from propagating + // to the background thread that dispatches ErrorDataReceived, + // which would crash the process. + LogStderrCallbackFailed(logger, endpointName, ex); + } LogReadStderr(logger, endpointName, data); } }; + process.ErrorDataReceived += errorHandler; // We need both stdin and stdout to use a no-BOM UTF-8 encoding. On .NET Core, // we can use ProcessStartInfo.StandardOutputEncoding/StandardInputEncoding, but @@ -193,7 +205,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken = process.BeginErrorReadLine(); - return new StdioClientSessionTransport(_options, process, endpointName, stderrRollingLog, _loggerFactory); + return new StdioClientSessionTransport(_options, process, endpointName, stderrRollingLog, errorHandler, _loggerFactory); } catch (Exception ex) { @@ -201,6 +213,11 @@ public async Task ConnectAsync(CancellationToken cancellationToken = try { + if (process is not null && errorHandler is not null) + { + process.ErrorDataReceived -= errorHandler; + } + DisposeProcess(process, processStarted, _options.ShutdownTimeout); } catch (Exception ex2) @@ -228,18 +245,6 @@ internal static void DisposeProcess( process.KillTree(shutdownTimeout); } - // Ensure all redirected stderr/stdout events have been dispatched - // before disposing. Only the no-arg WaitForExit() guarantees this; - // WaitForExit(int) (as used by KillTree) does not. - // This should not hang: either the process already exited on its own - // (no child processes holding handles), or KillTree killed the entire - // process tree. If it does take too long, the test infrastructure's - // own timeout will catch it. - if (!processRunning && HasExited(process)) - { - process.WaitForExit(); - } - // Invoke the callback while the process handle is still valid, // e.g. to read ExitCode before Dispose() invalidates it. beforeDispose?.Invoke(); @@ -299,6 +304,9 @@ private static string EscapeArgumentString(string argument) => [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} received stderr log: '{Data}'.")] private static partial void LogReadStderr(ILogger logger, string endpointName, string data); + [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} StandardErrorLines callback failed.")] + private static partial void LogStderrCallbackFailed(ILogger logger, string endpointName, Exception exception); + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} started server process with PID {ProcessId}.")] private static partial void LogTransportProcessStarted(ILogger logger, string endpointName, int processId); diff --git a/src/ModelContextProtocol.Core/Client/StreamClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/StreamClientSessionTransport.cs index 19306349f..38df5b7e6 100644 --- a/src/ModelContextProtocol.Core/Client/StreamClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StreamClientSessionTransport.cs @@ -164,6 +164,21 @@ private async Task ProcessMessageAsync(string line, CancellationToken cancellati } } + /// + /// Cancels the shutdown token to signal that the transport is shutting down, + /// without performing any other cleanup. + /// + protected void CancelShutdown() + { + try + { + _shutdownCts?.Cancel(); + } + catch (ObjectDisposedException) + { + } + } + protected virtual async ValueTask CleanupAsync(Exception? error = null, CancellationToken cancellationToken = default) { LogTransportShuttingDown(Name); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/ClientConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/ClientConformanceTests.cs index 1418574a7..72d075fe7 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/ClientConformanceTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/ClientConformanceTests.cs @@ -82,24 +82,29 @@ public async Task RunConformanceTest(string scenario) var process = new Process { StartInfo = startInfo }; - process.OutputDataReceived += (sender, e) => + // Protect callbacks with try/catch to prevent ITestOutputHelper from + // throwing on a background thread if events arrive after the test completes. + DataReceivedEventHandler outputHandler = (sender, e) => { if (e.Data != null) { - _output.WriteLine(e.Data); + try { _output.WriteLine(e.Data); } catch { } outputBuilder.AppendLine(e.Data); } }; - process.ErrorDataReceived += (sender, e) => + DataReceivedEventHandler errorHandler = (sender, e) => { if (e.Data != null) { - _output.WriteLine(e.Data); + try { _output.WriteLine(e.Data); } catch { } errorBuilder.AppendLine(e.Data); } }; + process.OutputDataReceived += outputHandler; + process.ErrorDataReceived += errorHandler; + process.Start(); process.BeginOutputReadLine(); process.BeginErrorReadLine(); @@ -112,6 +117,8 @@ public async Task RunConformanceTest(string scenario) catch (OperationCanceledException) { process.Kill(entireProcessTree: true); + process.OutputDataReceived -= outputHandler; + process.ErrorDataReceived -= errorHandler; return ( Success: false, Output: outputBuilder.ToString(), @@ -119,6 +126,9 @@ public async Task RunConformanceTest(string scenario) ); } + process.OutputDataReceived -= outputHandler; + process.ErrorDataReceived -= errorHandler; + var output = outputBuilder.ToString(); var error = errorBuilder.ToString(); var success = process.ExitCode == 0 || HasOnlyWarnings(output, error); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/ServerConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/ServerConformanceTests.cs index d2501456c..e538a6f3f 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/ServerConformanceTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/ServerConformanceTests.cs @@ -136,24 +136,29 @@ public async Task RunPendingConformanceTest_ServerSsePolling() var process = new Process { StartInfo = startInfo }; - process.OutputDataReceived += (sender, e) => + // Protect callbacks with try/catch to prevent ITestOutputHelper from + // throwing on a background thread if events arrive after the test completes. + DataReceivedEventHandler outputHandler = (sender, e) => { if (e.Data != null) { - output.WriteLine(e.Data); + try { output.WriteLine(e.Data); } catch { } outputBuilder.AppendLine(e.Data); } }; - process.ErrorDataReceived += (sender, e) => + DataReceivedEventHandler errorHandler = (sender, e) => { if (e.Data != null) { - output.WriteLine(e.Data); + try { output.WriteLine(e.Data); } catch { } errorBuilder.AppendLine(e.Data); } }; + process.OutputDataReceived += outputHandler; + process.ErrorDataReceived += errorHandler; + process.Start(); process.BeginOutputReadLine(); process.BeginErrorReadLine(); @@ -166,6 +171,8 @@ public async Task RunPendingConformanceTest_ServerSsePolling() catch (OperationCanceledException) { process.Kill(entireProcessTree: true); + process.OutputDataReceived -= outputHandler; + process.ErrorDataReceived -= errorHandler; return ( Success: false, Output: outputBuilder.ToString(), @@ -173,6 +180,9 @@ public async Task RunPendingConformanceTest_ServerSsePolling() ); } + process.OutputDataReceived -= outputHandler; + process.ErrorDataReceived -= errorHandler; + return ( Success: process.ExitCode == 0, Output: outputBuilder.ToString(), diff --git a/tests/ModelContextProtocol.ConformanceClient/Program.cs b/tests/ModelContextProtocol.ConformanceClient/Program.cs index 40dc424cb..b5e048dd0 100644 --- a/tests/ModelContextProtocol.ConformanceClient/Program.cs +++ b/tests/ModelContextProtocol.ConformanceClient/Program.cs @@ -105,85 +105,96 @@ OAuth = oauthOptions, }, loggerFactory: consoleLoggerFactory); -await using var mcpClient = await McpClient.CreateAsync(clientTransport, options, loggerFactory: consoleLoggerFactory); +try +{ + await using var mcpClient = await McpClient.CreateAsync(clientTransport, options, loggerFactory: consoleLoggerFactory); -bool success = true; + bool success = true; -switch (scenario) -{ - case "tools_call": + switch (scenario) { - var tools = await mcpClient.ListToolsAsync(); - Console.WriteLine($"Available tools: {string.Join(", ", tools.Select(t => t.Name))}"); + case "tools_call": + { + var tools = await mcpClient.ListToolsAsync(); + Console.WriteLine($"Available tools: {string.Join(", ", tools.Select(t => t.Name))}"); - // Call the "add_numbers" tool - var toolName = "add_numbers"; - Console.WriteLine($"Calling tool: {toolName}"); - var result = await mcpClient.CallToolAsync(toolName: toolName, arguments: new Dictionary + // Call the "add_numbers" tool + var toolName = "add_numbers"; + Console.WriteLine($"Calling tool: {toolName}"); + var result = await mcpClient.CallToolAsync(toolName: toolName, arguments: new Dictionary + { + { "a", 5 }, + { "b", 10 } + }); + success &= !(result.IsError == true); + break; + } + case "elicitation-sep1034-client-defaults": { - { "a", 5 }, - { "b", 10 } - }); - success &= !(result.IsError == true); - break; - } - case "elicitation-sep1034-client-defaults": - { - var tools = await mcpClient.ListToolsAsync(); - Console.WriteLine($"Available tools: {string.Join(", ", tools.Select(t => t.Name))}"); - var toolName = "test_client_elicitation_defaults"; - Console.WriteLine($"Calling tool: {toolName}"); - var result = await mcpClient.CallToolAsync(toolName: toolName, arguments: new Dictionary()); - success &= !(result.IsError == true); - break; - } - case "sse-retry": - { - var tools = await mcpClient.ListToolsAsync(); - Console.WriteLine($"Available tools: {string.Join(", ", tools.Select(t => t.Name))}"); - var toolName = "test_reconnection"; - Console.WriteLine($"Calling tool: {toolName}"); - var result = await mcpClient.CallToolAsync(toolName: toolName, arguments: new Dictionary()); - success &= !(result.IsError == true); - break; - } - case "auth/scope-step-up": - { - // Just testing that we can authenticate and list tools - var tools = await mcpClient.ListToolsAsync(); - Console.WriteLine($"Available tools: {string.Join(", ", tools.Select(t => t.Name))}"); - - // Call the "test_tool" tool - var toolName = "test-tool"; - Console.WriteLine($"Calling tool: {toolName}"); - var result = await mcpClient.CallToolAsync(toolName: toolName, arguments: new Dictionary + var tools = await mcpClient.ListToolsAsync(); + Console.WriteLine($"Available tools: {string.Join(", ", tools.Select(t => t.Name))}"); + var toolName = "test_client_elicitation_defaults"; + Console.WriteLine($"Calling tool: {toolName}"); + var result = await mcpClient.CallToolAsync(toolName: toolName, arguments: new Dictionary()); + success &= !(result.IsError == true); + break; + } + case "sse-retry": { - { "foo", "bar" }, - }); - success &= !(result.IsError == true); - break; - } - case "auth/scope-retry-limit": - { - // Try to list tools - this triggers the auth flow that always fails with 403. - // The test validates the client doesn't retry indefinitely. - try + var tools = await mcpClient.ListToolsAsync(); + Console.WriteLine($"Available tools: {string.Join(", ", tools.Select(t => t.Name))}"); + var toolName = "test_reconnection"; + Console.WriteLine($"Calling tool: {toolName}"); + var result = await mcpClient.CallToolAsync(toolName: toolName, arguments: new Dictionary()); + success &= !(result.IsError == true); + break; + } + case "auth/scope-step-up": { - await mcpClient.ListToolsAsync(); + // Just testing that we can authenticate and list tools + var tools = await mcpClient.ListToolsAsync(); + Console.WriteLine($"Available tools: {string.Join(", ", tools.Select(t => t.Name))}"); + + // Call the "test_tool" tool + var toolName = "test-tool"; + Console.WriteLine($"Calling tool: {toolName}"); + var result = await mcpClient.CallToolAsync(toolName: toolName, arguments: new Dictionary + { + { "foo", "bar" }, + }); + success &= !(result.IsError == true); + break; } - catch (Exception ex) + case "auth/scope-retry-limit": { - Console.WriteLine($"Expected auth failure: {ex.Message}"); + // Try to list tools - this triggers the auth flow that always fails with 403. + // The test validates the client doesn't retry indefinitely. + try + { + await mcpClient.ListToolsAsync(); + } + catch (Exception ex) + { + Console.WriteLine($"Expected auth failure: {ex.Message}"); + } + break; } - break; + default: + // No extra processing for other scenarios + break; } - default: - // No extra processing for other scenarios - break; -} -// Exit code 0 on success, 1 on failure -return success ? 0 : 1; + // Exit code 0 on success, 1 on failure + return success ? 0 : 1; +} +catch (Exception ex) +{ + // Report the error to stderr and exit with a non-zero code rather than + // crashing the process with an unhandled exception. An unhandled exception + // generates a crash dump which can abort the parent test host. + Console.Error.WriteLine($"Conformance client failed: {ex}"); + return 1; +} // Copied from ProtectedMcpClient sample // Simulate a user opening the browser and logging in diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsTransportsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsTransportsTests.cs index a52746b88..6af33d8d2 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsTransportsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsTransportsTests.cs @@ -9,16 +9,17 @@ namespace ModelContextProtocol.Tests.Configuration; public class McpServerBuilderExtensionsTransportsTests { [Fact] - public void WithStdioServerTransport_Sets_Transport() + public void WithStdioServerTransport_Registers_Transport() { var services = new ServiceCollection(); services.AddMcpServer().WithStdioServerTransport(); - var transportServiceType = services.FirstOrDefault(s => s.ServiceType == typeof(ITransport)); - Assert.NotNull(transportServiceType); - - var serviceProvider = services.BuildServiceProvider(); - Assert.IsType(serviceProvider.GetRequiredService()); + // Verify StdioServerTransport is registered for ITransport, but don't resolve it — + // doing so opens Console.OpenStandardInput() which permanently blocks a thread pool + // thread on the test host's stdin. StdioServerTransport should only be used in a + // dedicated child process, not in-process. + var transportDescriptor = services.FirstOrDefault(s => s.ServiceType == typeof(ITransport)); + Assert.NotNull(transportDescriptor); } [Fact] diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerOptionsSetupTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerOptionsSetupTests.cs index 40165d58c..689aba9d0 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerOptionsSetupTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerOptionsSetupTests.cs @@ -181,9 +181,10 @@ public async Task ServerCapabilities_WithManualResourceSubscribeCapability_AndWi }; }) .WithResources() - .WithStdioServerTransport(); + .WithStreamServerTransport(Stream.Null, Stream.Null); - var options = services.BuildServiceProvider().GetRequiredService>().Value; + await using var sp = services.BuildServiceProvider(); + var options = sp.GetRequiredService>().Value; // The options should preserve the user's manually set capabilities Assert.NotNull(options.Capabilities?.Resources); diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerResourceCapabilityIntegrationTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerResourceCapabilityIntegrationTests.cs index f9b6217cf..1d9fe554a 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerResourceCapabilityIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerResourceCapabilityIntegrationTests.cs @@ -106,9 +106,9 @@ public async Task Resources_AreExposed_WhenSubscribeCapabilitySetInAddMcpServerO }; }) .WithResources() - .WithStdioServerTransport(); + .WithStreamServerTransport(Stream.Null, Stream.Null); - var serviceProvider = services.BuildServiceProvider(); + await using var serviceProvider = services.BuildServiceProvider(); var mcpOptions = serviceProvider.GetRequiredService>().Value; // Verify capabilities are preserved @@ -122,15 +122,15 @@ public async Task Resources_AreExposed_WhenSubscribeCapabilitySetInAddMcpServerO } [Fact] - public void ResourcesCapability_IsCreated_WhenOnlyResourcesAreProvided() + public async Task ResourcesCapability_IsCreated_WhenOnlyResourcesAreProvided() { // Test that ResourcesCapability is created even without handlers or manual setting var services = new ServiceCollection(); var builder = services.AddMcpServer() .WithResources() - .WithStdioServerTransport(); + .WithStreamServerTransport(Stream.Null, Stream.Null); - var serviceProvider = services.BuildServiceProvider(); + await using var serviceProvider = services.BuildServiceProvider(); var mcpOptions = serviceProvider.GetRequiredService>().Value; // Resources are registered diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerLoggingLevelTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerLoggingLevelTests.cs index 10116e70e..a45d19604 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerLoggingLevelTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerLoggingLevelTests.cs @@ -15,29 +15,29 @@ public McpServerLoggingLevelTests() } [Fact] - public void CanCreateServerWithLoggingLevelHandler() + public async Task CanCreateServerWithLoggingLevelHandler() { var services = new ServiceCollection(); services.AddMcpServer() - .WithStdioServerTransport() + .WithStreamServerTransport(Stream.Null, Stream.Null) .WithSetLoggingLevelHandler(async (ctx, ct) => new EmptyResult()); - var provider = services.BuildServiceProvider(); + await using var provider = services.BuildServiceProvider(); provider.GetRequiredService(); } [Fact] - public void AddingLoggingLevelHandlerSetsLoggingCapability() + public async Task AddingLoggingLevelHandlerSetsLoggingCapability() { var services = new ServiceCollection(); services.AddMcpServer() - .WithStdioServerTransport() + .WithStreamServerTransport(Stream.Null, Stream.Null) .WithSetLoggingLevelHandler(async (ctx, ct) => new EmptyResult()); - var provider = services.BuildServiceProvider(); + await using var provider = services.BuildServiceProvider(); var server = provider.GetRequiredService(); @@ -46,12 +46,12 @@ public void AddingLoggingLevelHandlerSetsLoggingCapability() } [Fact] - public void ServerWithoutCallingLoggingLevelHandlerDoesNotSetLoggingCapability() + public async Task ServerWithoutCallingLoggingLevelHandlerDoesNotSetLoggingCapability() { var services = new ServiceCollection(); services.AddMcpServer() - .WithStdioServerTransport(); - var provider = services.BuildServiceProvider(); + .WithStreamServerTransport(Stream.Null, Stream.Null); + await using var provider = services.BuildServiceProvider(); var server = provider.GetRequiredService(); Assert.Null(server.ServerOptions.Capabilities?.Logging); } diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs index dfcc0bd4f..6a67bac9a 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs @@ -29,12 +29,12 @@ public McpServerResourceTests() } [Fact] - public void CanCreateServerWithResource() + public async Task CanCreateServerWithResource() { var services = new ServiceCollection(); services.AddMcpServer() - .WithStdioServerTransport() + .WithStreamServerTransport(Stream.Null, Stream.Null) .WithListResourcesHandler(async (ctx, ct) => { return new ListResourcesResult @@ -58,19 +58,19 @@ public void CanCreateServerWithResource() }; }); - var provider = services.BuildServiceProvider(); + await using var provider = services.BuildServiceProvider(); provider.GetRequiredService(); } [Fact] - public void CanCreateServerWithResourceTemplates() + public async Task CanCreateServerWithResourceTemplates() { var services = new ServiceCollection(); services.AddMcpServer() - .WithStdioServerTransport() + .WithStreamServerTransport(Stream.Null, Stream.Null) .WithListResourceTemplatesHandler(async (ctx, ct) => { return new ListResourceTemplatesResult @@ -94,17 +94,17 @@ public void CanCreateServerWithResourceTemplates() }; }); - var provider = services.BuildServiceProvider(); + await using var provider = services.BuildServiceProvider(); provider.GetRequiredService(); } [Fact] - public void CreatingReadHandlerWithNoListHandlerSucceeds() + public async Task CreatingReadHandlerWithNoListHandlerSucceeds() { var services = new ServiceCollection(); services.AddMcpServer() - .WithStdioServerTransport() + .WithStreamServerTransport(Stream.Null, Stream.Null) .WithReadResourceHandler(async (ctx, ct) => { return new ReadResourceResult @@ -117,7 +117,7 @@ public void CreatingReadHandlerWithNoListHandlerSucceeds() }] }; }); - var sp = services.BuildServiceProvider(); + await using var sp = services.BuildServiceProvider(); sp.GetRequiredService(); } diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs index d84ea9377..1a999fd14 100644 --- a/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs @@ -59,6 +59,17 @@ public async Task CreateAsync_ValidProcessInvalidServer_StdErrCallbackInvoked() Assert.Contains(id, sb.ToString()); } + [Fact(Skip = "Platform not supported by this test.", SkipUnless = nameof(IsStdErrCallbackSupported))] + public async Task CreateAsync_StdErrCallbackThrows_DoesNotCrashProcess() + { + StdioClientTransport transport = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? + new(new() { Command = "cmd", Arguments = ["/c", "echo fail >&2 & exit /b 1"], StandardErrorLines = _ => throw new InvalidOperationException("boom") }, LoggerFactory) : + new(new() { Command = "sh", Arguments = ["-c", "echo fail >&2; exit 1"], StandardErrorLines = _ => throw new InvalidOperationException("boom") }, LoggerFactory); + + // Should throw IOException for the failed server, not crash the host process. + await Assert.ThrowsAnyAsync(() => McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); + } + [Theory] [InlineData(null)] [InlineData("argument with spaces")] diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs index 22ac43d95..e47269686 100644 --- a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs @@ -27,8 +27,11 @@ public StdioServerTransportTests(ITestOutputHelper testOutputHelper) [Fact] public async Task Constructor_Should_Initialize_With_Valid_Parameters() { - // Act - await using var transport = new StdioServerTransport(_serverOptions); + // Use StreamServerTransport with Stream.Null rather than StdioServerTransport. + // StdioServerTransport opens Console.OpenStandardInput() which permanently + // blocks a thread pool thread on the test host's stdin. StdioServerTransport + // should only be instantiated in a dedicated child process. + await using var transport = new StreamServerTransport(Stream.Null, Stream.Null, _serverOptions.ServerInfo?.Name); // Assert Assert.NotNull(transport);