From 5fe36a290fe6831aa00f5d5ba32f98cc45e3cf59 Mon Sep 17 00:00:00 2001 From: Maciek Kisiel Date: Fri, 10 Apr 2026 08:22:47 +0000 Subject: [PATCH] mcp: add DNS rebinding and cross origin protections to SSE transport --- docs/mcpgodebug.md | 7 +- internal/docs/mcpgodebug.src.md | 7 +- mcp/sse.go | 57 ++++++++++- mcp/sse_test.go | 173 ++++++++++++++++++++++++++++++++ mcp/streamable.go | 6 +- 5 files changed, 236 insertions(+), 14 deletions(-) diff --git a/docs/mcpgodebug.md b/docs/mcpgodebug.md index 714abfe8..7c6773dc 100644 --- a/docs/mcpgodebug.md +++ b/docs/mcpgodebug.md @@ -24,7 +24,7 @@ Options listed below will be removed in the 1.6.0 version of the SDK. - `disablecrossoriginprotection` added. If set to `1`, newly added cross-origin protection will be disabled. The default behavior was changed to enable - cross-origin protection. + cross-origin protection. **Removal of this option was postponed until 1.7.0.** ### 1.4.0 @@ -37,5 +37,6 @@ Options listed below will be removed in the 1.6.0 version of the SDK. - `disablelocalhostprotection` added. If set to `1`, newly added DNS rebinding protection will be disabled. The default behavior was changed to enable DNS rebinding protection. The protection can also be disabled by setting the - `DisableLocalhostProtection` field in the `StreamableHTTPOptions` struct to - `true`, which is the recommended way to disable the protection long term. + `DisableLocalhostProtection` field in the `StreamableHTTPOptions` or + `SSEOptions` struct to `true`, which is the recommended way to disable + the protection long term. **Removal of this option was postponed until 1.7.0.** diff --git a/internal/docs/mcpgodebug.src.md b/internal/docs/mcpgodebug.src.md index e13500b5..592e9238 100644 --- a/internal/docs/mcpgodebug.src.md +++ b/internal/docs/mcpgodebug.src.md @@ -23,7 +23,7 @@ Options listed below will be removed in the 1.6.0 version of the SDK. - `disablecrossoriginprotection` added. If set to `1`, newly added cross-origin protection will be disabled. The default behavior was changed to enable - cross-origin protection. + cross-origin protection. **Removal of this option was postponed until 1.7.0.** ### 1.4.0 @@ -36,5 +36,6 @@ Options listed below will be removed in the 1.6.0 version of the SDK. - `disablelocalhostprotection` added. If set to `1`, newly added DNS rebinding protection will be disabled. The default behavior was changed to enable DNS rebinding protection. The protection can also be disabled by setting the - `DisableLocalhostProtection` field in the `StreamableHTTPOptions` struct to - `true`, which is the recommended way to disable the protection long term. \ No newline at end of file + `DisableLocalhostProtection` field in the `StreamableHTTPOptions` or + `SSEOptions` struct to `true`, which is the recommended way to disable + the protection long term. **Removal of this option was postponed until 1.7.0.** diff --git a/mcp/sse.go b/mcp/sse.go index e57dad10..f8f156d8 100644 --- a/mcp/sse.go +++ b/mcp/sse.go @@ -10,11 +10,13 @@ import ( "crypto/rand" "fmt" "io" + "net" "net/http" "net/url" "sync" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" + "github.com/modelcontextprotocol/go-sdk/internal/util" "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) @@ -52,9 +54,25 @@ type SSEHandler struct { } // SSEOptions specifies options for an [SSEHandler]. -// for now, it is empty, but may be extended in future. -// https://github.com/modelcontextprotocol/go-sdk/issues/507 -type SSEOptions struct{} +type SSEOptions struct { + // DisableLocalhostProtection disables automatic DNS rebinding protection. + // By default, requests arriving via a localhost address (127.0.0.1, [::1]) + // that have a non-localhost Host header are rejected with 403 Forbidden. + // This protects against DNS rebinding attacks regardless of whether the + // server is listening on localhost specifically or on 0.0.0.0. + // + // Only disable this if you understand the security implications. + // See: https://modelcontextprotocol.io/specification/2025-11-25/basic/security_best_practices#local-mcp-server-compromise + DisableLocalhostProtection bool + + // CrossOriginProtection allows to customize cross-origin protection. + // The deny handler set in the CrossOriginProtection through SetDenyHandler + // is ignored. + // If nil, default (zero-value) cross-origin protection will be used. + // Use `disablecrossoriginprotection` MCPGODEBUG compatibility parameter + // to disable the default protection until v1.7.0. + CrossOriginProtection *http.CrossOriginProtection +} // NewSSEHandler returns a new [SSEHandler] that creates and manages MCP // sessions created via incoming HTTP requests. @@ -79,6 +97,10 @@ func NewSSEHandler(getServer func(request *http.Request) *Server, opts *SSEOptio s.opts = *opts } + if s.opts.CrossOriginProtection == nil { + s.opts.CrossOriginProtection = &http.CrossOriginProtection{} + } + return s } @@ -179,9 +201,34 @@ func (t *SSEServerTransport) Connect(context.Context) (Connection, error) { } func (h *SSEHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { - sessionID := req.URL.Query().Get("sessionid") + // DNS rebinding protection: auto-enabled for localhost servers. + // See: https://modelcontextprotocol.io/specification/2025-11-25/basic/security_best_practices#local-mcp-server-compromise + if !h.opts.DisableLocalhostProtection && disablelocalhostprotection != "1" { + if localAddr, ok := req.Context().Value(http.LocalAddrContextKey).(net.Addr); ok && localAddr != nil { + if util.IsLoopback(localAddr.String()) && !util.IsLoopback(req.Host) { + http.Error(w, fmt.Sprintf("Forbidden: invalid Host header %q", req.Host), http.StatusForbidden) + return + } + } + } - // TODO: consider checking Content-Type here. For now, we are lax. + if disablecrossoriginprotection != "1" { + // Verify the 'Origin' header to protect against CSRF attacks. + if err := h.opts.CrossOriginProtection.Check(req); err != nil { + http.Error(w, err.Error(), http.StatusForbidden) + return + } + // Validate 'Content-Type' header. + if req.Method == http.MethodPost { + contentType := req.Header.Get("Content-Type") + if contentType != "application/json" { + http.Error(w, "Content-Type must be 'application/json'", http.StatusUnsupportedMediaType) + return + } + } + } + + sessionID := req.URL.Query().Get("sessionid") // For POST requests, the message body is a message to send to a session. if req.Method == http.MethodPost { diff --git a/mcp/sse_test.go b/mcp/sse_test.go index 8746cc8b..fe230a51 100644 --- a/mcp/sse_test.go +++ b/mcp/sse_test.go @@ -9,8 +9,10 @@ import ( "context" "fmt" "io" + "net" "net/http" "net/http/httptest" + "strings" "sync/atomic" "testing" @@ -221,3 +223,174 @@ func TestSSE405AllowHeader(t *testing.T) { }) } } + +// TestSSELocalhostProtection verifies that DNS rebinding protection +// is automatically enabled for localhost servers. +func TestSSELocalhostProtection(t *testing.T) { + server := NewServer(testImpl, nil) + + tests := []struct { + name string + listenAddr string + hostHeader string + disableProtection bool + wantStatus int + }{ + { + name: "127.0.0.1 accepts 127.0.0.1", + listenAddr: "127.0.0.1:0", + hostHeader: "127.0.0.1:1234", + wantStatus: http.StatusOK, + }, + { + name: "127.0.0.1 accepts localhost", + listenAddr: "127.0.0.1:0", + hostHeader: "localhost:1234", + wantStatus: http.StatusOK, + }, + { + name: "127.0.0.1 rejects evil.com", + listenAddr: "127.0.0.1:0", + hostHeader: "evil.com", + wantStatus: http.StatusForbidden, + }, + { + name: "127.0.0.1 rejects evil.com:80", + listenAddr: "127.0.0.1:0", + hostHeader: "evil.com:80", + wantStatus: http.StatusForbidden, + }, + { + name: "127.0.0.1 rejects localhost.evil.com", + listenAddr: "127.0.0.1:0", + hostHeader: "localhost.evil.com", + wantStatus: http.StatusForbidden, + }, + { + name: "0.0.0.0 via localhost rejects evil.com", + listenAddr: "0.0.0.0:0", + hostHeader: "evil.com", + wantStatus: http.StatusForbidden, + }, + { + name: "disabled accepts evil.com", + listenAddr: "127.0.0.1:0", + hostHeader: "evil.com", + disableProtection: true, + wantStatus: http.StatusOK, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + opts := &SSEOptions{ + DisableLocalhostProtection: tt.disableProtection, + } + handler := NewSSEHandler(func(req *http.Request) *Server { return server }, opts) + + listener, err := net.Listen("tcp", tt.listenAddr) + if err != nil { + t.Fatalf("Failed to listen on %s: %v", tt.listenAddr, err) + } + defer listener.Close() + + srv := &http.Server{Handler: handler} + go srv.Serve(listener) + defer srv.Close() + + // Use a GET request since it's the entry point for SSE sessions. + // For accepted requests, the response will be a hanging SSE stream, + // but we only need to check the initial status code. + req, err := http.NewRequest("GET", fmt.Sprintf("http://%s", listener.Addr().String()), nil) + if err != nil { + t.Fatal(err) + } + req.Host = tt.hostHeader + req.Header.Set("Accept", "text/event-stream") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + if got := resp.StatusCode; got != tt.wantStatus { + t.Errorf("Status code: got %d, want %d", got, tt.wantStatus) + } + }) + } +} + +func TestSSEOriginProtection(t *testing.T) { + server := NewServer(testImpl, nil) + + tests := []struct { + name string + protection *http.CrossOriginProtection + requestOrigin string + wantStatusCode int + }{ + { + name: "default protection with Origin header", + protection: nil, + requestOrigin: "https://example.com", + wantStatusCode: http.StatusForbidden, + }, + { + name: "custom protection with trusted origin and same Origin", + protection: func() *http.CrossOriginProtection { + p := http.NewCrossOriginProtection() + if err := p.AddTrustedOrigin("https://example.com"); err != nil { + t.Fatal(err) + } + return p + }(), + requestOrigin: "https://example.com", + wantStatusCode: http.StatusNotFound, // origin accepted; session not found + }, + { + name: "custom protection with trusted origin and different Origin", + protection: func() *http.CrossOriginProtection { + p := http.NewCrossOriginProtection() + if err := p.AddTrustedOrigin("https://example.com"); err != nil { + t.Fatal(err) + } + return p + }(), + requestOrigin: "https://malicious.com", + wantStatusCode: http.StatusForbidden, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + opts := &SSEOptions{ + CrossOriginProtection: tt.protection, + } + handler := NewSSEHandler(func(req *http.Request) *Server { return server }, opts) + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + // Use POST with a valid session-like URL to test origin protection + // without creating a hanging GET connection. + reqReader := strings.NewReader(`{"jsonrpc":"2.0","id":1,"method":"ping"}`) + req, err := http.NewRequest(http.MethodPost, httpServer.URL+"?sessionid=nonexistent", reqReader) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Origin", tt.requestOrigin) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + if got := resp.StatusCode; got != tt.wantStatusCode { + body, _ := io.ReadAll(resp.Body) + t.Errorf("Status code: got %d, want %d (body: %s)", got, tt.wantStatusCode, body) + } + }) + } +} diff --git a/mcp/streamable.go b/mcp/streamable.go index 76f2b0e4..b4cd693e 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -182,7 +182,7 @@ type StreamableHTTPOptions struct { // is ignored. // If nil, default (zero-value) cross-origin protection will be used. // Use `disablecrossoriginprotection` MCPGODEBUG compatibility parameter - // to disable the default protection until v1.6.0. + // to disable the default protection until v1.7.0. CrossOriginProtection *http.CrossOriginProtection } @@ -235,14 +235,14 @@ func (h *StreamableHTTPHandler) closeAll() { // disablelocalhostprotection is a compatibility parameter that allows to disable // DNS rebinding protection, which was added in the 1.4.0 version of the SDK. // See the documentation for the mcpgodebug package for instructions how to enable it. -// The option will be removed in the 1.6.0 version of the SDK. +// The option will be removed in the 1.7.0 version of the SDK. var disablelocalhostprotection = mcpgodebug.Value("disablelocalhostprotection") // disablecrossoriginprotection is a compatibility parameter that allows to disable // the verification of the 'Origin' and 'Content-Type' headers, which was added in // the 1.4.1 version of the SDK. See the documentation for the mcpgodebug package // for instructions how to enable it. -// The option will be removed in the 1.6.0 version of the SDK. +// The option will be removed in the 1.7.0 version of the SDK. var disablecrossoriginprotection = mcpgodebug.Value("disablecrossoriginprotection") func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {