From 3cc35038d21f1e8ed963972964c54a14feccd264 Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Thu, 9 Apr 2026 15:47:52 -0700 Subject: [PATCH] mcp: accept parameterized Content-Type types Use a shared helper for Content-Type parsing in streamable transport request validation and client response handling. --- mcp/streamable.go | 21 +++++++++++++-------- mcp/streamable_client_test.go | 3 ++- mcp/streamable_test.go | 28 +++++++++++++++++++++++++--- 3 files changed, 40 insertions(+), 12 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index 76f2b0e4..292eed07 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -264,12 +264,9 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque return } // Validate 'Content-Type' header. - if req.Method == http.MethodPost { - mediaType, _, err := mime.ParseMediaType(req.Header.Get("Content-Type")) - if err != nil || mediaType != "application/json" { - http.Error(w, "Content-Type must be 'application/json'", http.StatusUnsupportedMediaType) - return - } + if req.Method == http.MethodPost && baseMediaType(req.Header.Get("Content-Type")) != "application/json" { + http.Error(w, "Content-Type must be 'application/json'", http.StatusUnsupportedMediaType) + return } } @@ -560,6 +557,14 @@ func streamableAccepts(values []string) (jsonOK, streamOK bool) { return jsonOK, streamOK } +func baseMediaType(value string) string { + mediaType, _, err := mime.ParseMediaType(value) + if err != nil { + return "" + } + return mediaType +} + // A StreamableServerTransport implements the server side of the MCP streamable // transport. // @@ -1671,7 +1676,7 @@ func (c *streamableClientConn) connectStandaloneSSE() { resp.Body.Close() return } - if resp.Header.Get("Content-Type") != "text/event-stream" { + if baseMediaType(resp.Header.Get("Content-Type")) != "text/event-stream" { // modelcontextprotocol/go-sdk#736: some servers return 200 OK or redirect with // non-SSE content type instead of text/event-stream for the standalone // SSE stream. @@ -1855,7 +1860,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e return nil } - contentType := strings.TrimSpace(strings.SplitN(resp.Header.Get("Content-Type"), ";", 2)[0]) + contentType := baseMediaType(resp.Header.Get("Content-Type")) switch contentType { case "application/json": go c.handleJSON(requestSummary, resp) diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go index 43564fd3..9da8aeca 100644 --- a/mcp/streamable_client_test.go +++ b/mcp/streamable_client_test.go @@ -264,6 +264,7 @@ func TestStreamableClientGETHandling(t *testing.T) { contentType string }{ {http.StatusOK, "", "text/event-stream"}, + {http.StatusOK, "", "text/event-stream; charset=utf-8"}, {http.StatusMethodNotAllowed, "", "text/event-stream"}, //// The client error status code is not treated as an error in non-strict //// mode. @@ -274,7 +275,7 @@ func TestStreamableClientGETHandling(t *testing.T) { } for _, test := range tests { - t.Run(fmt.Sprintf("status=%d", test.status), func(t *testing.T) { + t.Run(fmt.Sprintf("status=%d content_type=%q", test.status, test.contentType), func(t *testing.T) { fake := &fakeStreamableServer{ t: t, responses: fakeResponses{ diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 36002775..592981fc 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -1506,9 +1506,9 @@ func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string, newSessionID := resp.Header.Get(sessionIDHeader) - contentType := resp.Header.Get("Content-Type") + contentType := baseMediaType(resp.Header.Get("Content-Type")) var respBody []byte - if strings.HasPrefix(contentType, "text/event-stream") { + if contentType == "text/event-stream" { r := readerInto{resp.Body, new(bytes.Buffer)} for evt, err := range scanEvents(r) { if err != nil { @@ -1525,7 +1525,7 @@ func (s streamableRequest) do(ctx context.Context, serverURL, sessionID string, } } respBody = r.w.Bytes() - } else if strings.HasPrefix(contentType, "application/json") { + } else if contentType == "application/json" { data, err := io.ReadAll(resp.Body) if err != nil { return newSessionID, resp.StatusCode, nil, fmt.Errorf("reading json body: %w", err) @@ -2047,6 +2047,28 @@ func TestStreamableGETWithoutEventStreamAccept(t *testing.T) { } } +func TestBaseMediaType(t *testing.T) { + tests := []struct { + name string + value string + want string + }{ + {name: "empty", want: ""}, + {name: "json", value: "application/json", want: "application/json"}, + {name: "json with params", value: "Application/JSON; charset=utf-8", want: "application/json"}, + {name: "event stream with params", value: "Text/Event-Stream; charset=utf-8", want: "text/event-stream"}, + {name: "invalid", value: "application/json; charset", want: ""}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if got := baseMediaType(test.value); got != test.want { + t.Errorf("baseMediaType(%q) = %q, want %q", test.value, got, test.want) + } + }) + } +} + func TestStreamableClientContextPropagation(t *testing.T) { type contextKey string const testKey = contextKey("test-key")