diff --git a/codec.go b/codec.go index b872239..60707f1 100644 --- a/codec.go +++ b/codec.go @@ -4,14 +4,19 @@ import ( "context" "fmt" "net/http" + "strconv" + "strings" "github.com/tiny-go/codec" "github.com/tiny-go/errors" ) const ( - acceptHeader = "Accept" - contentTypeHeader = "Content-Type" + acceptHeader = "Accept" + contentTypeHeader = "Content-Type" + contentLengthHeader = "Content-Length" + transferEncodingHeader = "Transfer-Encoding" + defaultTransferEncoding = "identity" ) // codecKey is a private unique key that is used to put/get codec from the context. @@ -35,16 +40,22 @@ func Codec(fn errors.HandlerFunc, codecs Codecs) Middleware { var reqCodec, resCodec codec.Codec // get request codec if reqCodec = codecs.Lookup(r.Header.Get(contentTypeHeader)); reqCodec == nil { - fn(w, fmt.Sprintf("unsupported request codec: %q", r.Header.Get(contentTypeHeader)), http.StatusBadRequest) - return + if isContentTypeHeaderRequired(r) { + fn(w, fmt.Sprintf("unsupported request codec: %q", r.Header.Get(contentTypeHeader)), http.StatusBadRequest) + return + } + } else { + r = r.WithContext(context.WithValue(r.Context(), codecKey{"req"}, reqCodec)) } - r = r.WithContext(context.WithValue(r.Context(), codecKey{"req"}, reqCodec)) // get response codec if resCodec = codecs.Lookup(r.Header.Get(acceptHeader)); resCodec == nil { - fn(w, fmt.Sprintf("unsupported response codec: %q", r.Header.Get(acceptHeader)), http.StatusBadRequest) - return + if isAcceptHeaderRequired(r) { + fn(w, fmt.Sprintf("unsupported response codec: %q", r.Header.Get(acceptHeader)), http.StatusBadRequest) + return + } + } else { + r = r.WithContext(context.WithValue(r.Context(), codecKey{"res"}, resCodec)) } - r = r.WithContext(context.WithValue(r.Context(), codecKey{"res"}, resCodec)) // call the next handler next.ServeHTTP(w, r) }) @@ -62,3 +73,72 @@ func ResponseCodecFromContext(ctx context.Context) codec.Codec { codec, _ := ctx.Value(codecKey{"res"}).(codec.Codec) return codec } + +// isContentTypeHeaderRequired returns the HTTP method request body type requirement. +// By RFC7231 (https://tools.ietf.org/html/rfc7231) only POST, PUT and PATCH methods +// should contain a request body. DELETE method body is optional. +func isContentTypeHeaderRequired(r *http.Request) bool { + switch r.Method { + // Body is required + case http.MethodPost: fallthrough + case http.MethodPut: fallthrough + case http.MethodPatch: + return shouldRequestBodyBeProcessed(r, true) + // May have body, but not required + case http.MethodDelete: + return shouldRequestBodyBeProcessed(r, false) + // No body + case http.MethodGet: fallthrough + case http.MethodHead: fallthrough + case http.MethodConnect: fallthrough + case http.MethodOptions: fallthrough + case http.MethodTrace: fallthrough + default: + return false + } +} + +// isAcceptHeaderRequired returns the HTTP method response body type requirement. +// By RFC7231 (https://tools.ietf.org/html/rfc7231) only GET, POST, CONNECT, +// OPTIONS and PATCH methods should indicate the details of a response body. +// DELETE method response body is optional. +func isAcceptHeaderRequired(r *http.Request) bool { + switch r.Method { + // Body is required + case http.MethodGet: fallthrough + case http.MethodPost: fallthrough + case http.MethodConnect: fallthrough + case http.MethodOptions: fallthrough + case http.MethodPatch: + return true + // May have body, but not required + case http.MethodDelete: fallthrough + // No body + case http.MethodHead: fallthrough + case http.MethodPut: fallthrough + case http.MethodTrace: fallthrough + default: + return false + } +} + +func shouldRequestBodyBeProcessed(r *http.Request, required bool) bool { + transferEncoding := r.Header.Get(transferEncodingHeader) + hasRequestBody := transferEncoding != "" && !strings.EqualFold(transferEncoding, defaultTransferEncoding) + + hasRequestBody = hasRequestBody || func() bool { + contentLengthStr := r.Header.Get(contentLengthHeader) + if contentLengthStr != "" { + contentLength, err := strconv.Atoi(contentLengthStr) + if err != nil || contentLength < 0 { + return false + } + + return contentLength > 0 + } + + return required + }() + + return hasRequestBody +} diff --git a/codec_test.go b/codec_test.go index d77e8a7..9cd00ac 100644 --- a/codec_test.go +++ b/codec_test.go @@ -21,36 +21,109 @@ func TestCodecFromList(t *testing.T) { body string } + type Data struct { + Test string + } + cases := []testCase{ { - title: "should throw an error if request codec in not supported", + title: "should throw an error if a request codec is required but not supported", handler: Codec(nil, driver.DummyRegistry{&json.JSON{}, &xml.XML{}})(nil), request: func() *http.Request { - r, _ := http.NewRequest(http.MethodGet, "", nil) + r, _ := http.NewRequest(http.MethodPost, "", nil) r.Header.Set(contentTypeHeader, "unknown") + r.Header.Set(contentLengthHeader, "1") return r }(), code: http.StatusBadRequest, body: "unsupported request codec: \"unknown\"\n", }, { - title: "should throw an error if response codec in not supported", + title: "should ignore a request codec if not supported but not required", + handler: Codec(nil, driver.DummyRegistry{&json.JSON{}, &xml.XML{}})( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("passed")) + }), + ), + request: func() *http.Request { + r, _ := http.NewRequest(http.MethodDelete, "", nil) + r.Header.Set(contentTypeHeader, "unknown") + return r + }(), + code: http.StatusOK, + body: "passed", + }, + { + title: "should use a request codec if supported but not required", + handler: Codec(nil, driver.DummyRegistry{&json.JSON{}, &xml.XML{}})( + BodyClose( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var data Data + RequestCodecFromContext(r.Context()).Decoder(r.Body).Decode(&data) + w.Write([]byte(data.Test)) + }), + ), + ), + request: func() *http.Request { + r, _ := http.NewRequest(http.MethodDelete, "", strings.NewReader("{\"test\":\"passed\"}\n")) + r.Header.Set(contentTypeHeader, "application/json") + r.Header.Set(contentLengthHeader, "1") + return r + }(), + code: http.StatusOK, + body: "passed", + }, + { + title: "should throw an error if response codec is required but not supported", handler: Codec(nil, driver.DummyRegistry{&json.JSON{}, &xml.XML{}})(nil), request: func() *http.Request { - r, _ := http.NewRequest(http.MethodGet, "", nil) + r, _ := http.NewRequest(http.MethodPost, "", nil) r.Header.Set(contentTypeHeader, "application/json") + r.Header.Set(contentLengthHeader, "0") r.Header.Set(acceptHeader, "unknown") return r }(), code: http.StatusBadRequest, body: "unsupported response codec: \"unknown\"\n", }, + { + title: "should ignore a response codec if not supported but not required", + handler: Codec(nil, driver.DummyRegistry{&json.JSON{}, &xml.XML{}})( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("passed")) + }), + ), + request: func() *http.Request { + r, _ := http.NewRequest(http.MethodDelete, "", nil) + r.Header.Set(acceptHeader, "unknown") + return r + }(), + code: http.StatusOK, + body: "passed", + }, + { + title: "should use a response codec if supported but not required", + handler: Codec(nil, driver.DummyRegistry{&json.JSON{}, &xml.XML{}})( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + data := Data{ + Test: "passed", + } + ResponseCodecFromContext(r.Context()).Encoder(w).Encode(data) + }), + ), + request: func() *http.Request { + r, _ := http.NewRequest(http.MethodDelete, "", nil) + r.Header.Set(acceptHeader, "application/xml") + return r + }(), + code: http.StatusOK, + body: "passed", + }, { title: "should find corresponding codecs and handle the request successfully", handler: Codec(nil, driver.DummyRegistry{&json.JSON{}, &xml.XML{}})( BodyClose( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - type Data struct{ Test string } var data Data RequestCodecFromContext(r.Context()).Decoder(r.Body).Decode(&data) ResponseCodecFromContext(r.Context()).Encoder(w).Encode(data) @@ -58,8 +131,9 @@ func TestCodecFromList(t *testing.T) { ), ), request: func() *http.Request { - r, _ := http.NewRequest(http.MethodGet, "", strings.NewReader("{\"test\":\"passed\"}\n")) + r, _ := http.NewRequest(http.MethodPost, "", strings.NewReader("{\"test\":\"passed\"}\n")) r.Header.Set(contentTypeHeader, "application/json") + r.Header.Set(contentLengthHeader, "1") r.Header.Set(acceptHeader, "application/xml") return r }(),