diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index a958cba..700166e 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -20,12 +20,12 @@ jobs: test: strategy: matrix: - os: [ubuntu-latest, macos-latest, windows-latest] + os: [ ubuntu-latest, macos-latest, windows-latest ] # Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy # Echo tests with last four major releases (unless there are pressing vulnerabilities) # As we depend on `golang.org/x/` libraries which only support last 2 Go releases we could have situations when # we derive from last four major releases promise. - go: ["1.24", "1.25"] + go: [ "1.25" ] name: ${{ matrix.os }} @ Go ${{ matrix.go }} runs-on: ${{ matrix.os }} steps: diff --git a/CHANGELOG.md b/CHANGELOG.md index 54efdb8..62bb5ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # Changelog +## v5.0.0 - 2026-01.18 + +* Echo v5 support + + ## v4.4.0 - 2025-11-20 **Enhancements** diff --git a/README.md b/README.md index 9c6cc48..a1b525a 100644 --- a/README.md +++ b/README.md @@ -14,11 +14,13 @@ as JWT implementation. This repository does not use semantic versioning. MAJOR version tracks which Echo version should be used. MINOR version tracks API changes (possibly backwards incompatible) and PATCH version is incremented for fixes. -NB: When `golang-jwt` MAJOR version changes this library will release MINOR version with **breaking change**. Always +NB: When `golang-jwt` MAJOR version changes this library will release MINOR version with **breaking change**. Always add at least one integration test in your project. -For Echo `v4` use `v4.x.y` releases. +For Echo `v5` use `v5.x.y` releases. Minimal needed Echo versions: + +* `v5.0.0` needs Echo `v5.0.0+` * `v4.0.0` needs Echo `v4.7.0+` `main` branch is compatible with the latest Echo version. @@ -26,21 +28,25 @@ Minimal needed Echo versions: ## Usage Add JWT middleware dependency with go modules + ```bash -go get github.com/labstack/echo-jwt/v4 +go get github.com/labstack/echo-jwt/v5 ``` Use as import statement + ```go -import "github.com/labstack/echo-jwt/v4" +import "github.com/labstack/echo-jwt/v5" ``` Add middleware in simplified form, by providing only the secret key + ```go e.Use(echojwt.JWT([]byte("secret"))) ``` Add middleware with configuration options + ```go e.Use(echojwt.WithConfig(echojwt.Config{ // ... @@ -50,15 +56,16 @@ e.Use(echojwt.WithConfig(echojwt.Config{ ``` Extract token in handler + ```go import "github.com/golang-jwt/jwt/v5" // ... -e.GET("/", func(c echo.Context) error { - token, ok := c.Get("user").(*jwt.Token) // by default token is stored under `user` key - if !ok { - return errors.New("JWT token missing or invalid") +e.GET("/", func(c *echo.Context) error { + token, err := echo.ContextGet[*jwt.Token](c,"user") + if err != nil { + return echo.ErrUnauthorized.Wrap(err) } claims, ok := token.Claims.(jwt.MapClaims) // by default claims is of type `jwt.MapClaims` if !ok { @@ -70,8 +77,12 @@ e.GET("/", func(c echo.Context) error { ## IMPORTANT: Integration Testing with JWT Library -Ensure that your project includes at least one integration test to detect changes in major versions of the `golang-jwt/jwt` library early. -This is crucial because type assertions like `token := c.Get("user").(*jwt.Token)` may fail silently if the imported version of the JWT library (e.g., `import "github.com/golang-jwt/jwt/v5"`) differs from the version used internally by dependencies (e.g., echo-jwt may now use `v6`). Such discrepancies can lead to invalid casts, causing your handlers to panic or throw errors. Integration tests help safeguard against these version mismatches. +Ensure that your project includes at least one integration test to detect changes in major versions of the +`golang-jwt/jwt` library early. +This is crucial because type assertions like `token := c.Get("user").(*jwt.Token)` may fail silently if the imported +version of the JWT library (e.g., `import "github.com/golang-jwt/jwt/v5"`) differs from the version used internally by +dependencies (e.g., echo-jwt may now use `v6`). Such discrepancies can lead to invalid casts, causing your handlers to +panic or throw errors. Integration tests help safeguard against these version mismatches. ```go func TestIntegrationMiddlewareWithHandler(t *testing.T) { @@ -97,55 +108,58 @@ func TestIntegrationMiddlewareWithHandler(t *testing.T) { } ``` - ## Full example ```go package main import ( - "errors" - "github.com/golang-jwt/jwt/v5" - "github.com/labstack/echo-jwt/v4" - "github.com/labstack/echo/v4" - "github.com/labstack/echo/v4/middleware" - "log" - "net/http" -) + "errors" + "log/slog" -func main() { - e := echo.New() - e.Use(middleware.Logger()) - e.Use(middleware.Recover()) + "github.com/golang-jwt/jwt/v5" + "github.com/labstack/echo-jwt/v5" + "github.com/labstack/echo/v5" + "github.com/labstack/echo/v5/middleware" - e.Use(echojwt.WithConfig(echojwt.Config{ - SigningKey: []byte("secret"), - })) + "net/http" +) - e.GET("/", func(c echo.Context) error { - token, ok := c.Get("user").(*jwt.Token) // by default token is stored under `user` key - if !ok { - return errors.New("JWT token missing or invalid") - } - claims, ok := token.Claims.(jwt.MapClaims) // by default claims is of type `jwt.MapClaims` - if !ok { - return errors.New("failed to cast claims as jwt.MapClaims") - } - return c.JSON(http.StatusOK, claims) - }) - - if err := e.Start(":8080"); err != http.ErrServerClosed { - log.Fatal(err) - } +func main() { + e := echo.New() + e.Use(middleware.RequestLogger()) + e.Use(middleware.Recover()) + + e.Use(echojwt.WithConfig(echojwt.Config{ + SigningKey: []byte("secret"), + })) + + e.GET("/", func(c *echo.Context) error { + token, err := echo.ContextGet[*jwt.Token](c, "user") + if err != nil { + return echo.ErrUnauthorized.Wrap(err) + } + claims, ok := token.Claims.(jwt.MapClaims) // by default claims is of type `jwt.MapClaims` + if !ok { + return errors.New("failed to cast claims as jwt.MapClaims") + } + return c.JSON(http.StatusOK, claims) + }) + + if err := e.Start(":8080"); err != nil { + slog.Error("Failed to start server", "error", err) + } } ``` Test with + ```bash curl -v -H "Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ" http://localhost:8080 ``` Output should be + ```bash * Trying 127.0.0.1:8080... * Connected to localhost (127.0.0.1) port 8080 (#0) diff --git a/extractors.go b/extractors.go deleted file mode 100644 index f72537e..0000000 --- a/extractors.go +++ /dev/null @@ -1,205 +0,0 @@ -// SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: © 2016 LabStack and Echo contributors - -package echojwt - -import ( - "errors" - "fmt" - "github.com/labstack/echo/v4" - "github.com/labstack/echo/v4/middleware" - "net/textproto" - "strings" -) - -const ( - // extractorLimit is arbitrary number to limit values extractor can return. this limits possible resource exhaustion - // attack vector - extractorLimit = 20 -) - -// TokenExtractionError is catch all type for all errors that occur when the token is extracted from the request. This -// helps to distinguish extractor errors from token parsing errors even if custom extractors or token parsing functions -// are being used that have their own custom errors. -type TokenExtractionError struct { - Err error -} - -// Is checks if target error is same as TokenExtractionError -func (e TokenExtractionError) Is(target error) bool { return target == ErrJWTMissing } // to provide some compatibility with older error handling logic - -func (e *TokenExtractionError) Error() string { return e.Err.Error() } -func (e *TokenExtractionError) Unwrap() error { return e.Err } - -var errHeaderExtractorValueMissing = errors.New("missing value in request header") -var errHeaderExtractorValueInvalid = errors.New("invalid value in request header") -var errQueryExtractorValueMissing = errors.New("missing value in the query string") -var errParamExtractorValueMissing = errors.New("missing value in path params") -var errCookieExtractorValueMissing = errors.New("missing value in cookies") -var errFormExtractorValueMissing = errors.New("missing value in the form") - -// CreateExtractors creates ValuesExtractors from given lookups. -// Lookups is a string in the form of ":" or ":,:" that is used -// to extract key from the request. -// Possible values: -// - "header:" or "header::" -// `` is argument value to cut/trim prefix of the extracted value. This is useful if header -// value has static prefix like `Authorization: ` where part that we -// want to cut is ` ` note the space at the end. -// In case of basic authentication `Authorization: Basic ` prefix we want to remove is `Basic `. -// - "query:" -// - "param:" -// - "form:" -// - "cookie:" -// -// Multiple sources example: -// - "header:Authorization,header:X-Api-Key" -func CreateExtractors(lookups string) ([]middleware.ValuesExtractor, error) { - if lookups == "" { - return nil, nil - } - sources := strings.Split(lookups, ",") - var extractors = make([]middleware.ValuesExtractor, 0) - for _, source := range sources { - parts := strings.Split(source, ":") - if len(parts) < 2 { - return nil, fmt.Errorf("extractor source for lookup could not be split into needed parts: %v", source) - } - - switch parts[0] { - case "query": - extractors = append(extractors, valuesFromQuery(parts[1])) - case "param": - extractors = append(extractors, valuesFromParam(parts[1])) - case "cookie": - extractors = append(extractors, valuesFromCookie(parts[1])) - case "form": - extractors = append(extractors, valuesFromForm(parts[1])) - case "header": - prefix := "" - if len(parts) > 2 { - prefix = parts[2] - } - extractors = append(extractors, valuesFromHeader(parts[1], prefix)) - } - } - return extractors, nil -} - -// valuesFromHeader returns a functions that extracts values from the request header. -// valuePrefix is parameter to remove first part (prefix) of the extracted value. This is useful if header value has static -// prefix like `Authorization: ` where part that we want to remove is ` ` -// note the space at the end. In case of basic authentication `Authorization: Basic ` prefix we want to remove -// is `Basic `. In case of JWT tokens `Authorization: Bearer ` prefix is `Bearer `. -// If prefix is left empty the whole value is returned. -func valuesFromHeader(header string, valuePrefix string) middleware.ValuesExtractor { - prefixLen := len(valuePrefix) - // standard library parses http.Request header keys in canonical form but we may provide something else so fix this - header = textproto.CanonicalMIMEHeaderKey(header) - return func(c echo.Context) ([]string, error) { - values := c.Request().Header.Values(header) - if len(values) == 0 { - return nil, errHeaderExtractorValueMissing - } - - result := make([]string, 0) - for i, value := range values { - if prefixLen == 0 { - result = append(result, value) - if i >= extractorLimit-1 { - break - } - continue - } - if len(value) > prefixLen && strings.EqualFold(value[:prefixLen], valuePrefix) { - result = append(result, value[prefixLen:]) - if i >= extractorLimit-1 { - break - } - } - } - - if len(result) == 0 { - if prefixLen > 0 { - return nil, errHeaderExtractorValueInvalid - } - return nil, errHeaderExtractorValueMissing - } - return result, nil - } -} - -// valuesFromQuery returns a function that extracts values from the query string. -func valuesFromQuery(param string) middleware.ValuesExtractor { - return func(c echo.Context) ([]string, error) { - result := c.QueryParams()[param] - if len(result) == 0 { - return nil, errQueryExtractorValueMissing - } else if len(result) > extractorLimit-1 { - result = result[:extractorLimit] - } - return result, nil - } -} - -// valuesFromParam returns a function that extracts values from the url param string. -func valuesFromParam(param string) middleware.ValuesExtractor { - return func(c echo.Context) ([]string, error) { - result := make([]string, 0) - paramVales := c.ParamValues() - for i, p := range c.ParamNames() { - if param == p { - result = append(result, paramVales[i]) - if i >= extractorLimit-1 { - break - } - } - } - if len(result) == 0 { - return nil, errParamExtractorValueMissing - } - return result, nil - } -} - -// valuesFromCookie returns a function that extracts values from the named cookie. -func valuesFromCookie(name string) middleware.ValuesExtractor { - return func(c echo.Context) ([]string, error) { - cookies := c.Cookies() - if len(cookies) == 0 { - return nil, errCookieExtractorValueMissing - } - - result := make([]string, 0) - for i, cookie := range cookies { - if name == cookie.Name { - result = append(result, cookie.Value) - if i >= extractorLimit-1 { - break - } - } - } - if len(result) == 0 { - return nil, errCookieExtractorValueMissing - } - return result, nil - } -} - -// valuesFromForm returns a function that extracts values from the form field. -func valuesFromForm(name string) middleware.ValuesExtractor { - return func(c echo.Context) ([]string, error) { - if c.Request().Form == nil { - _ = c.Request().ParseMultipartForm(32 << 20) // same what `c.Request().FormValue(name)` does - } - values := c.Request().Form[name] - if len(values) == 0 { - return nil, errFormExtractorValueMissing - } - if len(values) > extractorLimit-1 { - values = values[:extractorLimit] - } - result := append([]string{}, values...) - return result, nil - } -} diff --git a/extractors_test.go b/extractors_test.go deleted file mode 100644 index 2a13440..0000000 --- a/extractors_test.go +++ /dev/null @@ -1,624 +0,0 @@ -package echojwt - -import ( - "bytes" - "errors" - "fmt" - "github.com/labstack/echo/v4" - "github.com/stretchr/testify/assert" - "mime/multipart" - "net/http" - "net/http/httptest" - "net/url" - "strings" - "testing" -) - -func TestTokenExtractionError_Is(t *testing.T) { - given := echo.ErrUnauthorized.SetInternal(&TokenExtractionError{Err: errCookieExtractorValueMissing}) - - assert.True(t, errors.Is(given, ErrJWTMissing)) - assert.True(t, errors.Is(given, errCookieExtractorValueMissing)) -} - -func TestTokenExtractionError_Error(t *testing.T) { - given := &TokenExtractionError{Err: errCookieExtractorValueMissing} - assert.Equal(t, "missing value in cookies", given.Error()) -} - -func TestTokenExtractionError_Unwrap(t *testing.T) { - given := &TokenExtractionError{Err: errCookieExtractorValueMissing} - assert.Equal(t, errCookieExtractorValueMissing, given.Unwrap()) -} - -type pathParam struct { - name string - value string -} - -func setPathParams(c echo.Context, params []pathParam) { - names := make([]string, 0, len(params)) - values := make([]string, 0, len(params)) - for _, pp := range params { - names = append(names, pp.name) - values = append(values, pp.value) - } - c.SetParamNames(names...) - c.SetParamValues(values...) -} - -func TestCreateExtractors(t *testing.T) { - var testCases = []struct { - name string - givenRequest func() *http.Request - givenPathParams []pathParam - whenLoopups string - expectValues []string - expectCreateError string - expectError string - }{ - { - name: "ok, header", - givenRequest: func() *http.Request { - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAuthorization, "Bearer token") - return req - }, - whenLoopups: "header:Authorization:Bearer ", - expectValues: []string{"token"}, - }, - { - name: "ok, form", - givenRequest: func() *http.Request { - f := make(url.Values) - f.Set("name", "Jon Snow") - - req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode())) - req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) - return req - }, - whenLoopups: "form:name", - expectValues: []string{"Jon Snow"}, - }, - { - name: "ok, cookie", - givenRequest: func() *http.Request { - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderCookie, "_csrf=token") - return req - }, - whenLoopups: "cookie:_csrf", - expectValues: []string{"token"}, - }, - { - name: "ok, param", - givenPathParams: []pathParam{ - {name: "id", value: "123"}, - }, - whenLoopups: "param:id", - expectValues: []string{"123"}, - }, - { - name: "ok, query", - givenRequest: func() *http.Request { - req := httptest.NewRequest(http.MethodGet, "/?id=999", nil) - return req - }, - whenLoopups: "query:id", - expectValues: []string{"999"}, - }, - { - name: "nok, invalid lookup", - whenLoopups: "query", - expectCreateError: "extractor source for lookup could not be split into needed parts: query", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := echo.New() - - req := httptest.NewRequest(http.MethodGet, "/", nil) - if tc.givenRequest != nil { - req = tc.givenRequest() - } - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - if tc.givenPathParams != nil { - setPathParams(c, tc.givenPathParams) - } - - extractors, err := CreateExtractors(tc.whenLoopups) - if tc.expectCreateError != "" { - assert.EqualError(t, err, tc.expectCreateError) - return - } - assert.NoError(t, err) - - for _, e := range extractors { - values, eErr := e(c) - assert.Equal(t, tc.expectValues, values) - if tc.expectError != "" { - assert.EqualError(t, eErr, tc.expectError) - return - } - assert.NoError(t, eErr) - } - }) - } -} - -func TestValuesFromHeader(t *testing.T) { - exampleRequest := func(req *http.Request) { - req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==") - } - - var testCases = []struct { - name string - givenRequest func(req *http.Request) - whenName string - whenValuePrefix string - expectValues []string - expectError string - }{ - { - name: "ok, single value", - givenRequest: exampleRequest, - whenName: echo.HeaderAuthorization, - whenValuePrefix: "basic ", - expectValues: []string{"dXNlcjpwYXNzd29yZA=="}, - }, - { - name: "ok, single value, case insensitive", - givenRequest: exampleRequest, - whenName: echo.HeaderAuthorization, - whenValuePrefix: "Basic ", - expectValues: []string{"dXNlcjpwYXNzd29yZA=="}, - }, - { - name: "ok, multiple value", - givenRequest: func(req *http.Request) { - req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==") - req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0") - }, - whenName: echo.HeaderAuthorization, - whenValuePrefix: "basic ", - expectValues: []string{"dXNlcjpwYXNzd29yZA==", "dGVzdDp0ZXN0"}, - }, - { - name: "ok, empty prefix", - givenRequest: exampleRequest, - whenName: echo.HeaderAuthorization, - whenValuePrefix: "", - expectValues: []string{"basic dXNlcjpwYXNzd29yZA=="}, - }, - { - name: "nok, no matching due different prefix", - givenRequest: func(req *http.Request) { - req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==") - req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0") - }, - whenName: echo.HeaderAuthorization, - whenValuePrefix: "Bearer ", - expectError: errHeaderExtractorValueInvalid.Error(), - }, - { - name: "nok, no matching due different prefix", - givenRequest: func(req *http.Request) { - req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==") - req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0") - }, - whenName: echo.HeaderWWWAuthenticate, - whenValuePrefix: "", - expectError: errHeaderExtractorValueMissing.Error(), - }, - { - name: "nok, no headers", - givenRequest: nil, - whenName: echo.HeaderAuthorization, - whenValuePrefix: "basic ", - expectError: errHeaderExtractorValueMissing.Error(), - }, - { - name: "ok, prefix, cut values over extractorLimit", - givenRequest: func(req *http.Request) { - for i := 1; i <= 25; i++ { - req.Header.Add(echo.HeaderAuthorization, fmt.Sprintf("basic %v", i)) - } - }, - whenName: echo.HeaderAuthorization, - whenValuePrefix: "basic ", - expectValues: []string{ - "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", - "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", - }, - }, - { - name: "ok, cut values over extractorLimit", - givenRequest: func(req *http.Request) { - for i := 1; i <= 25; i++ { - req.Header.Add(echo.HeaderAuthorization, fmt.Sprintf("%v", i)) - } - }, - whenName: echo.HeaderAuthorization, - whenValuePrefix: "", - expectValues: []string{ - "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", - "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := echo.New() - - req := httptest.NewRequest(http.MethodGet, "/", nil) - if tc.givenRequest != nil { - tc.givenRequest(req) - } - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - extractor := valuesFromHeader(tc.whenName, tc.whenValuePrefix) - - values, err := extractor(c) - assert.Equal(t, tc.expectValues, values) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestValuesFromQuery(t *testing.T) { - var testCases = []struct { - name string - givenQueryPart string - whenName string - expectValues []string - expectError string - }{ - { - name: "ok, single value", - givenQueryPart: "?id=123&name=test", - whenName: "id", - expectValues: []string{"123"}, - }, - { - name: "ok, multiple value", - givenQueryPart: "?id=123&id=456&name=test", - whenName: "id", - expectValues: []string{"123", "456"}, - }, - { - name: "nok, missing value", - givenQueryPart: "?id=123&name=test", - whenName: "nope", - expectError: errQueryExtractorValueMissing.Error(), - }, - { - name: "ok, cut values over extractorLimit", - givenQueryPart: "?name=test" + - "&id=1&id=2&id=3&id=4&id=5&id=6&id=7&id=8&id=9&id=10" + - "&id=11&id=12&id=13&id=14&id=15&id=16&id=17&id=18&id=19&id=20" + - "&id=21&id=22&id=23&id=24&id=25", - whenName: "id", - expectValues: []string{ - "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", - "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := echo.New() - - req := httptest.NewRequest(http.MethodGet, "/"+tc.givenQueryPart, nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - extractor := valuesFromQuery(tc.whenName) - - values, err := extractor(c) - assert.Equal(t, tc.expectValues, values) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestValuesFromParam(t *testing.T) { - examplePathParams := []pathParam{ - {name: "id", value: "123"}, - {name: "gid", value: "456"}, - {name: "gid", value: "789"}, - } - examplePathParams20 := make([]pathParam, 0) - for i := 1; i < 25; i++ { - examplePathParams20 = append(examplePathParams20, pathParam{name: "id", value: fmt.Sprintf("%v", i)}) - } - - var testCases = []struct { - name string - givenPathParams []pathParam - whenName string - expectValues []string - expectError string - }{ - { - name: "ok, single value", - givenPathParams: examplePathParams, - whenName: "id", - expectValues: []string{"123"}, - }, - { - name: "ok, multiple value", - givenPathParams: examplePathParams, - whenName: "gid", - expectValues: []string{"456", "789"}, - }, - { - name: "nok, no values", - givenPathParams: nil, - whenName: "nope", - expectValues: nil, - expectError: errParamExtractorValueMissing.Error(), - }, - { - name: "nok, no matching value", - givenPathParams: examplePathParams, - whenName: "nope", - expectValues: nil, - expectError: errParamExtractorValueMissing.Error(), - }, - { - name: "ok, cut values over extractorLimit", - givenPathParams: examplePathParams20, - whenName: "id", - expectValues: []string{ - "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", - "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := echo.New() - - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - if tc.givenPathParams != nil { - setPathParams(c, tc.givenPathParams) - } - - extractor := valuesFromParam(tc.whenName) - - values, err := extractor(c) - assert.Equal(t, tc.expectValues, values) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestValuesFromCookie(t *testing.T) { - exampleRequest := func(req *http.Request) { - req.Header.Set(echo.HeaderCookie, "_csrf=token") - } - - var testCases = []struct { - name string - givenRequest func(req *http.Request) - whenName string - expectValues []string - expectError string - }{ - { - name: "ok, single value", - givenRequest: exampleRequest, - whenName: "_csrf", - expectValues: []string{"token"}, - }, - { - name: "ok, multiple value", - givenRequest: func(req *http.Request) { - req.Header.Add(echo.HeaderCookie, "_csrf=token") - req.Header.Add(echo.HeaderCookie, "_csrf=token2") - }, - whenName: "_csrf", - expectValues: []string{"token", "token2"}, - }, - { - name: "nok, no matching cookie", - givenRequest: exampleRequest, - whenName: "xxx", - expectValues: nil, - expectError: errCookieExtractorValueMissing.Error(), - }, - { - name: "nok, no cookies at all", - givenRequest: nil, - whenName: "xxx", - expectValues: nil, - expectError: errCookieExtractorValueMissing.Error(), - }, - { - name: "ok, cut values over extractorLimit", - givenRequest: func(req *http.Request) { - for i := 1; i < 25; i++ { - req.Header.Add(echo.HeaderCookie, fmt.Sprintf("_csrf=%v", i)) - } - }, - whenName: "_csrf", - expectValues: []string{ - "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", - "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := echo.New() - - req := httptest.NewRequest(http.MethodGet, "/", nil) - if tc.givenRequest != nil { - tc.givenRequest(req) - } - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - extractor := valuesFromCookie(tc.whenName) - - values, err := extractor(c) - assert.Equal(t, tc.expectValues, values) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestValuesFromForm(t *testing.T) { - examplePostFormRequest := func(mod func(v *url.Values)) *http.Request { - f := make(url.Values) - f.Set("name", "Jon Snow") - f.Set("emails[]", "jon@labstack.com") - if mod != nil { - mod(&f) - } - - req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode())) - req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) - - return req - } - exampleGetFormRequest := func(mod func(v *url.Values)) *http.Request { - f := make(url.Values) - f.Set("name", "Jon Snow") - f.Set("emails[]", "jon@labstack.com") - if mod != nil { - mod(&f) - } - - req := httptest.NewRequest(http.MethodGet, "/?"+f.Encode(), nil) - return req - } - - exampleMultiPartFormRequest := func(mod func(w *multipart.Writer)) *http.Request { - var b bytes.Buffer - w := multipart.NewWriter(&b) - w.WriteField("name", "Jon Snow") - w.WriteField("emails[]", "jon@labstack.com") - if mod != nil { - mod(w) - } - - fw, _ := w.CreateFormFile("upload", "my.file") - fw.Write([]byte(`
hi
`)) - w.Close() - - req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(b.String())) - req.Header.Add(echo.HeaderContentType, w.FormDataContentType()) - - return req - } - - var testCases = []struct { - name string - givenRequest *http.Request - whenName string - expectValues []string - expectError string - }{ - { - name: "ok, POST form, single value", - givenRequest: examplePostFormRequest(nil), - whenName: "emails[]", - expectValues: []string{"jon@labstack.com"}, - }, - { - name: "ok, POST form, multiple value", - givenRequest: examplePostFormRequest(func(v *url.Values) { - v.Add("emails[]", "snow@labstack.com") - }), - whenName: "emails[]", - expectValues: []string{"jon@labstack.com", "snow@labstack.com"}, - }, - { - name: "ok, POST multipart/form, multiple value", - givenRequest: exampleMultiPartFormRequest(func(w *multipart.Writer) { - w.WriteField("emails[]", "snow@labstack.com") - }), - whenName: "emails[]", - expectValues: []string{"jon@labstack.com", "snow@labstack.com"}, - }, - { - name: "ok, GET form, single value", - givenRequest: exampleGetFormRequest(nil), - whenName: "emails[]", - expectValues: []string{"jon@labstack.com"}, - }, - { - name: "ok, GET form, multiple value", - givenRequest: examplePostFormRequest(func(v *url.Values) { - v.Add("emails[]", "snow@labstack.com") - }), - whenName: "emails[]", - expectValues: []string{"jon@labstack.com", "snow@labstack.com"}, - }, - { - name: "nok, POST form, value missing", - givenRequest: examplePostFormRequest(nil), - whenName: "nope", - expectError: errFormExtractorValueMissing.Error(), - }, - { - name: "ok, cut values over extractorLimit", - givenRequest: examplePostFormRequest(func(v *url.Values) { - for i := 1; i < 25; i++ { - v.Add("id[]", fmt.Sprintf("%v", i)) - } - }), - whenName: "id[]", - expectValues: []string{ - "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", - "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := echo.New() - - req := tc.givenRequest - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - extractor := valuesFromForm(tc.whenName) - - values, err := extractor(c) - assert.Equal(t, tc.expectValues, values) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - }) - } -} diff --git a/go.mod b/go.mod index b2c46c7..a1177e7 100644 --- a/go.mod +++ b/go.mod @@ -1,25 +1,16 @@ -module github.com/labstack/echo-jwt/v4 +module github.com/labstack/echo-jwt/v5 -go 1.24.0 +go 1.25.0 require ( github.com/golang-jwt/jwt/v5 v5.3.0 - github.com/labstack/echo/v4 v4.13.4 + github.com/labstack/echo/v5 v5.0.0 github.com/stretchr/testify v1.11.1 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/labstack/gommon v0.4.2 // indirect - github.com/mattn/go-colorable v0.1.14 // indirect - github.com/mattn/go-isatty v0.0.20 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/valyala/bytebufferpool v1.0.0 // indirect - github.com/valyala/fasttemplate v1.2.2 // indirect - golang.org/x/crypto v0.45.0 // indirect - golang.org/x/net v0.47.0 // indirect - golang.org/x/sys v0.38.0 // indirect - golang.org/x/text v0.31.0 // indirect golang.org/x/time v0.14.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index ba0b2a5..d6a4f56 100644 --- a/go.sum +++ b/go.sum @@ -2,33 +2,16 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= -github.com/labstack/echo/v4 v4.13.4 h1:oTZZW+T3s9gAu5L8vmzihV7/lkXGZuITzTQkTEhcXEA= -github.com/labstack/echo/v4 v4.13.4/go.mod h1:g63b33BZ5vZzcIUF8AtRH40DrTlXnx4UMC8rBdndmjQ= -github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0= -github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU= -github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= -github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= -github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= -github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/labstack/echo/v5 v5.0.0 h1:JHKGrI0cbNsNMyKvranuY0C94O4hSM7yc/HtwcV3Na4= +github.com/labstack/echo/v5 v5.0.0/go.mod h1:SyvlSdObGjRXeQfCCXW/sybkZdOOQZBmpKF0bvALaeo= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= -github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= -github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= -github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= -golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= -golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= -golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= -golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= -golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= -golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= +golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= +golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= +golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= +golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= diff --git a/jwt.go b/jwt.go index 2b6b48b..cd19b8a 100644 --- a/jwt.go +++ b/jwt.go @@ -9,8 +9,8 @@ import ( "net/http" "github.com/golang-jwt/jwt/v5" - "github.com/labstack/echo/v4" - "github.com/labstack/echo/v4/middleware" + "github.com/labstack/echo/v5" + "github.com/labstack/echo/v5/middleware" ) // Config defines the config for JWT middleware. @@ -22,7 +22,9 @@ type Config struct { BeforeFunc middleware.BeforeFunc // SuccessHandler defines a function which is executed for a valid token. - SuccessHandler func(c echo.Context) + // In case SuccessHandler error the middleware stops handler chain execution and + // returns error. + SuccessHandler func(c *echo.Context) error // ErrorHandler defines a function which is executed when all lookups have been done and none of them passed Validator // function. ErrorHandler is executed with last missing (ErrExtractionValueMissing) or an invalid key. @@ -31,7 +33,7 @@ type Config struct { // Note: when error handler swallows the error (returns nil) middleware continues handler chain execution towards handler. // This is useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users // In that case you can use ErrorHandler to set default public JWT token value to request and continue with handler chain. - ErrorHandler func(c echo.Context, err error) error + ErrorHandler func(c *echo.Context, err error) error // ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandler decides to // ignore the error (by returning `nil`). @@ -101,12 +103,12 @@ type Config struct { // ParseTokenFunc defines a user-defined function that parses token from given auth. Returns an error when token // parsing fails or parsed token is invalid. // Defaults to implementation using `github.com/golang-jwt/jwt` as JWT implementation library - ParseTokenFunc func(c echo.Context, auth string) (interface{}, error) + ParseTokenFunc func(c *echo.Context, auth string) (interface{}, error) // Claims are extendable claims data defining token content. Used by default ParseTokenFunc implementation. // Not used if custom ParseTokenFunc is set. // Optional. Defaults to function returning jwt.MapClaims - NewClaimsFunc func(c echo.Context) jwt.Claims + NewClaimsFunc func(c *echo.Context) jwt.Claims } const ( @@ -134,6 +136,19 @@ func (e TokenParsingError) Is(target error) bool { return target == ErrJWTInvali func (e *TokenParsingError) Error() string { return e.Err.Error() } func (e *TokenParsingError) Unwrap() error { return e.Err } +// TokenExtractionError is catch all type for all errors that occur when the token is extracted from the request. This +// helps to distinguish extractor errors from token parsing errors even if custom extractors or token parsing functions +// are being used that have their own custom errors. +type TokenExtractionError struct { + Err error +} + +// Is checks if target error is same as TokenExtractionError +func (e TokenExtractionError) Is(target error) bool { return target == ErrJWTMissing } // to provide some compatibility with older error handling logic + +func (e *TokenExtractionError) Error() string { return e.Err.Error() } +func (e *TokenExtractionError) Unwrap() error { return e.Err } + // TokenError is used to return error with error occurred JWT token when processing JWT token type TokenError struct { Token *jwt.Token @@ -184,7 +199,7 @@ func (config Config) ToMiddleware() (echo.MiddlewareFunc, error) { } if config.NewClaimsFunc == nil { - config.NewClaimsFunc = func(c echo.Context) jwt.Claims { + config.NewClaimsFunc = func(c *echo.Context) jwt.Claims { return jwt.MapClaims{} } } @@ -197,7 +212,7 @@ func (config Config) ToMiddleware() (echo.MiddlewareFunc, error) { if config.ParseTokenFunc == nil { config.ParseTokenFunc = config.defaultParseTokenFunc } - extractors, ceErr := CreateExtractors(config.TokenLookup) + extractors, ceErr := middleware.CreateExtractors(config.TokenLookup, 1) if ceErr != nil { return nil, ceErr } @@ -206,7 +221,7 @@ func (config Config) ToMiddleware() (echo.MiddlewareFunc, error) { } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { + return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } @@ -217,7 +232,7 @@ func (config Config) ToMiddleware() (echo.MiddlewareFunc, error) { var lastExtractorErr error var lastTokenErr error for _, extractor := range extractors { - auths, extrErr := extractor(c) + auths, _, extrErr := extractor(c) if extrErr != nil { lastExtractorErr = extrErr continue @@ -231,7 +246,9 @@ func (config Config) ToMiddleware() (echo.MiddlewareFunc, error) { // Store user information from token into context. c.Set(config.ContextKey, token) if config.SuccessHandler != nil { - config.SuccessHandler(c) + if sErr := config.SuccessHandler(c); sErr != nil { + return sErr + } } return next(c) } @@ -254,10 +271,10 @@ func (config Config) ToMiddleware() (echo.MiddlewareFunc, error) { } if lastTokenErr == nil { - return ErrJWTMissing.WithInternal(err) + return ErrJWTMissing.Wrap(err) } - return ErrJWTInvalid.WithInternal(err) + return ErrJWTInvalid.Wrap(err) } }, nil } @@ -284,7 +301,7 @@ func (config Config) defaultKeyFunc(token *jwt.Token) (interface{}, error) { // defaultParseTokenFunc creates JWTGo implementation for ParseTokenFunc. // // error returns TokenError. -func (config Config) defaultParseTokenFunc(c echo.Context, auth string) (interface{}, error) { +func (config Config) defaultParseTokenFunc(c *echo.Context, auth string) (interface{}, error) { token, err := jwt.ParseWithClaims(auth, config.NewClaimsFunc(c), config.KeyFunc) if err != nil { return nil, &TokenError{Token: token, Err: err} diff --git a/jwt_benchmark_test.go b/jwt_benchmark_test.go index cd6f795..43c52f2 100644 --- a/jwt_benchmark_test.go +++ b/jwt_benchmark_test.go @@ -1,17 +1,18 @@ package echojwt import ( - "github.com/golang-jwt/jwt/v5" - "github.com/labstack/echo/v4" "net/http" "net/http/httptest" "testing" + + "github.com/golang-jwt/jwt/v5" + "github.com/labstack/echo/v5" ) func BenchmarkJWTSuccessPath(b *testing.B) { e := echo.New() - e.GET("/", func(c echo.Context) error { + e.GET("/", func(c *echo.Context) error { token := c.Get("user").(*jwt.Token) return c.JSON(http.StatusTeapot, token.Claims) }) @@ -40,7 +41,7 @@ func BenchmarkJWTSuccessPath(b *testing.B) { func BenchmarkJWTErrorPath(b *testing.B) { e := echo.New() - e.GET("/", func(c echo.Context) error { + e.GET("/", func(c *echo.Context) error { token := c.Get("user").(*jwt.Token) return c.JSON(http.StatusTeapot, token.Claims) }) diff --git a/jwt_external_test.go b/jwt_external_test.go index e38df49..9fa3add 100644 --- a/jwt_external_test.go +++ b/jwt_external_test.go @@ -6,14 +6,15 @@ package echojwt_test import ( "errors" "fmt" - "github.com/golang-jwt/jwt/v5" - echojwt "github.com/labstack/echo-jwt/v4" - "github.com/labstack/echo/v4" "io" "log" "net" "net/http" "time" + + "github.com/golang-jwt/jwt/v5" + echojwt "github.com/labstack/echo-jwt/v5" + "github.com/labstack/echo/v5" ) func ExampleWithConfig_usage() { @@ -23,7 +24,7 @@ func ExampleWithConfig_usage() { SigningKey: []byte("secret"), })) - e.GET("/", func(c echo.Context) error { + e.GET("/", func(c *echo.Context) error { // make sure that your imports are correct versions. for example if you use `"github.com/golang-jwt/jwt"` as // import this cast will fail and `"github.com/golang-jwt/jwt/v5"` will succeed. // Although `.(*jwt.Token)` looks exactly the same for both packages but this struct is still different diff --git a/jwt_integration_test.go b/jwt_integration_test.go index 25732bc..f722b29 100644 --- a/jwt_integration_test.go +++ b/jwt_integration_test.go @@ -5,12 +5,13 @@ package echojwt_test import ( "errors" - "github.com/golang-jwt/jwt/v5" - echojwt "github.com/labstack/echo-jwt/v4" - "github.com/labstack/echo/v4" "net/http" "net/http/httptest" "testing" + + "github.com/golang-jwt/jwt/v5" + echojwt "github.com/labstack/echo-jwt/v5" + "github.com/labstack/echo/v5" ) func TestIntegrationMiddlewareWithHandler(t *testing.T) { @@ -32,7 +33,7 @@ func TestIntegrationMiddlewareWithHandler(t *testing.T) { } } -func exampleHandler(c echo.Context) error { +func exampleHandler(c *echo.Context) error { // make sure that your imports are correct versions. for example if you use `"github.com/golang-jwt/jwt"` as // import this cast will fail and `"github.com/golang-jwt/jwt/v5"` will succeed. // Although `.(*jwt.Token)` looks exactly the same for both packages but this struct is still different diff --git a/jwt_test.go b/jwt_test.go index 1b47494..eb071f5 100644 --- a/jwt_test.go +++ b/jwt_test.go @@ -13,14 +13,14 @@ import ( "testing" "github.com/golang-jwt/jwt/v5" - "github.com/labstack/echo/v4" - "github.com/labstack/echo/v4/middleware" + "github.com/labstack/echo/v5" + "github.com/labstack/echo/v5/middleware" "github.com/stretchr/testify/assert" ) func TestTokenParsingError_Is(t *testing.T) { err := errors.New("parsing error") - given := echo.ErrUnauthorized.SetInternal(&TokenParsingError{Err: err}) + given := echo.ErrUnauthorized.Wrap(&TokenParsingError{Err: err}) assert.True(t, errors.Is(given, ErrJWTInvalid)) assert.True(t, errors.Is(given, err)) @@ -52,7 +52,7 @@ type jwtCustomClaims struct { func TestJWT(t *testing.T) { e := echo.New() - e.GET("/", func(c echo.Context) error { + e.GET("/", func(c *echo.Context) error { token := c.Get("user").(*jwt.Token) return c.JSON(http.StatusOK, token.Claims) }) @@ -71,7 +71,7 @@ func TestJWT(t *testing.T) { func TestJWT_combinations(t *testing.T) { e := echo.New() - handler := func(c echo.Context) error { + handler := func(c *echo.Context) error { return c.String(http.StatusOK, "test") } token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ" @@ -110,7 +110,7 @@ func TestJWT_combinations(t *testing.T) { SigningKey: validKey, SigningMethod: "RS256", }, - expectError: "code=401, message=invalid or expired jwt, internal=token is unverifiable: error while executing keyfunc: unexpected jwt signing method=HS256", + expectError: "code=401, message=invalid or expired jwt, err=token is unverifiable: error while executing keyfunc: unexpected jwt signing method=HS256", }, { name: "Invalid key", @@ -118,7 +118,7 @@ func TestJWT_combinations(t *testing.T) { config: Config{ SigningKey: invalidKey, }, - expectError: "code=401, message=invalid or expired jwt, internal=token signature is invalid: signature is invalid", + expectError: "code=401, message=invalid or expired jwt, err=token signature is invalid: signature is invalid", }, { name: "Valid JWT", @@ -140,7 +140,7 @@ func TestJWT_combinations(t *testing.T) { hdrAuth: validAuth, config: Config{ SigningKey: []byte("secret"), - NewClaimsFunc: func(c echo.Context) jwt.Claims { + NewClaimsFunc: func(c *echo.Context) jwt.Claims { return &jwtCustomClaims{ // this needs to be pointer to json unmarshalling to work jwtCustomInfo: jwtCustomInfo{ Name: "John Doe", @@ -156,14 +156,14 @@ func TestJWT_combinations(t *testing.T) { config: Config{ SigningKey: validKey, }, - expectError: "code=401, message=missing or malformed jwt, internal=invalid value in request header", + expectError: "code=401, message=missing or malformed jwt, err=invalid value in request header", }, { name: "Empty header auth field", config: Config{ SigningKey: validKey, }, - expectError: "code=401, message=missing or malformed jwt, internal=invalid value in request header", + expectError: "code=401, message=missing or malformed jwt, err=invalid value in request header", }, { name: "Valid query method", @@ -180,7 +180,7 @@ func TestJWT_combinations(t *testing.T) { TokenLookup: "query:jwt", }, reqURL: "/?a=b&jwtxyz=" + token, - expectError: "code=401, message=missing or malformed jwt, internal=missing value in the query string", + expectError: "code=401, message=missing or malformed jwt, err=missing value in the query string", }, { name: "Invalid query param value", @@ -189,7 +189,7 @@ func TestJWT_combinations(t *testing.T) { TokenLookup: "query:jwt", }, reqURL: "/?a=b&jwt=invalid-token", - expectError: "code=401, message=invalid or expired jwt, internal=token is malformed: token contains an invalid number of segments", + expectError: "code=401, message=invalid or expired jwt, err=token is malformed: token contains an invalid number of segments", }, { name: "Empty query", @@ -198,7 +198,7 @@ func TestJWT_combinations(t *testing.T) { TokenLookup: "query:jwt", }, reqURL: "/?a=b", - expectError: "code=401, message=missing or malformed jwt, internal=missing value in the query string", + expectError: "code=401, message=missing or malformed jwt, err=missing value in the query string", }, { config: Config{ @@ -231,7 +231,7 @@ func TestJWT_combinations(t *testing.T) { TokenLookup: "cookie:jwt", }, hdrCookie: "jwt=invalid", - expectError: "code=401, message=invalid or expired jwt, internal=token is malformed: token contains an invalid number of segments", + expectError: "code=401, message=invalid or expired jwt, err=token is malformed: token contains an invalid number of segments", }, { name: "Empty cookie", @@ -239,7 +239,7 @@ func TestJWT_combinations(t *testing.T) { SigningKey: validKey, TokenLookup: "cookie:jwt", }, - expectError: "code=401, message=missing or malformed jwt, internal=missing value in cookies", + expectError: "code=401, message=missing or malformed jwt, err=missing value in cookies", }, { name: "Valid form method", @@ -256,7 +256,7 @@ func TestJWT_combinations(t *testing.T) { TokenLookup: "form:jwt", }, formValues: map[string]string{"jwt": "invalid"}, - expectError: "code=401, message=invalid or expired jwt, internal=token is malformed: token contains an invalid number of segments", + expectError: "code=401, message=invalid or expired jwt, err=token is malformed: token contains an invalid number of segments", }, { name: "Empty form field", @@ -264,7 +264,7 @@ func TestJWT_combinations(t *testing.T) { SigningKey: validKey, TokenLookup: "form:jwt", }, - expectError: "code=401, message=missing or malformed jwt, internal=missing value in the form", + expectError: "code=401, message=missing or malformed jwt, err=missing value in the form", }, } @@ -292,8 +292,9 @@ func TestJWT_combinations(t *testing.T) { c := e.NewContext(req, res) if tc.reqURL == "/"+token { - c.SetParamNames("jwt") - c.SetParamValues(token) + c.SetPathValues(echo.PathValues{ + {Name: "jwt", Value: token}, + }) } mw, err := tc.config.ToMiddleware() @@ -327,7 +328,7 @@ func TestJWT_combinations(t *testing.T) { func TestJWTwithKID(t *testing.T) { e := echo.New() - handler := func(c echo.Context) error { + handler := func(c *echo.Context) error { return c.String(http.StatusOK, "test") } @@ -422,14 +423,14 @@ func TestConfig_skipper(t *testing.T) { e := echo.New() e.Use(WithConfig(Config{ - Skipper: func(context echo.Context) bool { + Skipper: func(context *echo.Context) bool { return true // skip everything }, SigningKey: []byte("secret"), })) isCalled := false - e.GET("/", func(c echo.Context) error { + e.GET("/", func(c *echo.Context) error { isCalled = true return c.String(http.StatusTeapot, "test") }) @@ -444,13 +445,13 @@ func TestConfig_skipper(t *testing.T) { func TestConfig_BeforeFunc(t *testing.T) { e := echo.New() - e.GET("/", func(c echo.Context) error { + e.GET("/", func(c *echo.Context) error { return c.String(http.StatusTeapot, "test") }) isCalled := false e.Use(WithConfig(Config{ - BeforeFunc: func(context echo.Context) { + BeforeFunc: func(context *echo.Context) { isCalled = true }, SigningKey: []byte("secret"), @@ -476,7 +477,7 @@ func TestConfig_ErrorHandling(t *testing.T) { name: "ok, ErrorHandler is executed", given: Config{ SigningKey: []byte("secret"), - ErrorHandler: func(c echo.Context, err error) error { + ErrorHandler: func(c *echo.Context, err error) error { return echo.NewHTTPError(http.StatusTeapot, "custom_error") }, }, @@ -486,7 +487,7 @@ func TestConfig_ErrorHandling(t *testing.T) { name: "ok, extractor errors are distinguishable as TokenExtractionError", given: Config{ SigningKey: []byte("secret"), - ErrorHandler: func(c echo.Context, err error) error { + ErrorHandler: func(c *echo.Context, err error) error { var extratorErr *TokenExtractionError if !errors.As(err, &extratorErr) { panic("must get TokenExtractionError") @@ -500,7 +501,7 @@ func TestConfig_ErrorHandling(t *testing.T) { name: "ok, token parsing errors are distinguishable as TokenParsingError", given: Config{ SigningKey: []byte("secret"), - ErrorHandler: func(c echo.Context, err error) error { + ErrorHandler: func(c *echo.Context, err error) error { var tpErr *TokenParsingError if !errors.As(err, &tpErr) { panic("must get TokenParsingError") @@ -520,7 +521,7 @@ func TestConfig_ErrorHandling(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := echo.New() - h := func(c echo.Context) error { + h := func(c *echo.Context) error { return c.String(http.StatusNotImplemented, "should not end up here") } @@ -550,7 +551,7 @@ func TestConfig_parseTokenErrorHandling(t *testing.T) { { name: "ok, ErrorHandler is executed", given: Config{ - ErrorHandler: func(c echo.Context, err error) error { + ErrorHandler: func(c *echo.Context, err error) error { return echo.NewHTTPError(http.StatusTeapot, "ErrorHandler: "+err.Error()) }, }, @@ -562,13 +563,13 @@ func TestConfig_parseTokenErrorHandling(t *testing.T) { t.Run(tc.name, func(t *testing.T) { e := echo.New() //e.Debug = true - e.GET("/", func(c echo.Context) error { + e.GET("/", func(c *echo.Context) error { return c.String(http.StatusNotImplemented, "should not end up here") }) config := tc.given parseTokenCalled := false - config.ParseTokenFunc = func(c echo.Context, auth string) (interface{}, error) { + config.ParseTokenFunc = func(c *echo.Context, auth string) (interface{}, error) { parseTokenCalled = true return nil, errors.New("parsing failed") } @@ -589,7 +590,7 @@ func TestConfig_parseTokenErrorHandling(t *testing.T) { func TestConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) { e := echo.New() - e.GET("/", func(c echo.Context) error { + e.GET("/", func(c *echo.Context) error { return c.String(http.StatusTeapot, "test") }) @@ -598,7 +599,7 @@ func TestConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) { signingKey := []byte("secret") config := Config{ - ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) { + ParseTokenFunc: func(c *echo.Context, auth string) (interface{}, error) { keyFunc := func(t *jwt.Token) (interface{}, error) { if t.Method.Alg() != "HS256" { return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"]) @@ -631,18 +632,19 @@ func TestConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) { func TestMustJWTWithConfig_SuccessHandler(t *testing.T) { e := echo.New() - e.GET("/", func(c echo.Context) error { + e.GET("/", func(c *echo.Context) error { success := c.Get("success").(string) user := c.Get("user").(string) return c.String(http.StatusTeapot, fmt.Sprintf("%v:%v", success, user)) }) mw, err := Config{ - ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) { + ParseTokenFunc: func(c *echo.Context, auth string) (interface{}, error) { return auth, nil }, - SuccessHandler: func(c echo.Context) { + SuccessHandler: func(c *echo.Context) error { c.Set("success", "yes") + return nil }, }.ToMiddleware() assert.NoError(t, err) @@ -657,11 +659,38 @@ func TestMustJWTWithConfig_SuccessHandler(t *testing.T) { assert.Equal(t, http.StatusTeapot, res.Code) } +func TestMustJWTWithConfig_SuccessHandlerError(t *testing.T) { + e := echo.New() + + e.GET("/", func(c *echo.Context) error { + return c.String(http.StatusTeapot, "should not end up here") + }) + + mw, err := Config{ + ParseTokenFunc: func(c *echo.Context, auth string) (interface{}, error) { + return auth, nil + }, + SuccessHandler: func(c *echo.Context) error { + return echo.ErrForbidden.Wrap(errors.New("nope")) + }, + }.ToMiddleware() + assert.NoError(t, err) + e.Use(mw) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Add(echo.HeaderAuthorization, "Bearer valid_token_base64") + res := httptest.NewRecorder() + e.ServeHTTP(res, req) + + assert.Equal(t, "{\"message\":\"Forbidden\"}\n", res.Body.String()) + assert.Equal(t, http.StatusForbidden, res.Code) +} + func TestJWTWithConfig_ContinueOnIgnoredError(t *testing.T) { var testCases = []struct { name string givenContinueOnIgnoredError bool - givenErrorHandler func(c echo.Context, err error) error + givenErrorHandler func(c *echo.Context, err error) error givenTokenLookup string whenAuthHeaders []string whenCookies []string @@ -674,7 +703,7 @@ func TestJWTWithConfig_ContinueOnIgnoredError(t *testing.T) { { name: "ok, with valid JWT from auth header", givenContinueOnIgnoredError: true, - givenErrorHandler: func(c echo.Context, err error) error { + givenErrorHandler: func(c *echo.Context, err error) error { return nil }, whenAuthHeaders: []string{"Bearer valid_token_base64"}, @@ -685,7 +714,7 @@ func TestJWTWithConfig_ContinueOnIgnoredError(t *testing.T) { { name: "ok, missing header, callNext and set public_token from error handler", givenContinueOnIgnoredError: true, - givenErrorHandler: func(c echo.Context, err error) error { + givenErrorHandler: func(c *echo.Context, err error) error { var extratorErr *TokenExtractionError if !errors.As(err, &extratorErr) { panic("must get TokenExtractionError") @@ -700,7 +729,7 @@ func TestJWTWithConfig_ContinueOnIgnoredError(t *testing.T) { { name: "ok, invalid token, callNext and set public_token from error handler", givenContinueOnIgnoredError: true, - givenErrorHandler: func(c echo.Context, err error) error { + givenErrorHandler: func(c *echo.Context, err error) error { // this is probably not realistic usecase. on parse error you probably want to return error if err.Error() != "parser_error" { panic("must get parser_error") @@ -716,7 +745,7 @@ func TestJWTWithConfig_ContinueOnIgnoredError(t *testing.T) { { name: "nok, invalid token, return error from error handler", givenContinueOnIgnoredError: true, - givenErrorHandler: func(c echo.Context, err error) error { + givenErrorHandler: func(c *echo.Context, err error) error { if err.Error() != "parser_error" { panic("must get parser_error") } @@ -730,7 +759,7 @@ func TestJWTWithConfig_ContinueOnIgnoredError(t *testing.T) { { name: "nok, ContinueOnIgnoredError but return error from error handler", givenContinueOnIgnoredError: true, - givenErrorHandler: func(c echo.Context, err error) error { + givenErrorHandler: func(c *echo.Context, err error) error { return echo.ErrUnauthorized }, whenAuthHeaders: []string{}, // no JWT header @@ -740,7 +769,7 @@ func TestJWTWithConfig_ContinueOnIgnoredError(t *testing.T) { { name: "nok, ContinueOnIgnoredError=false", givenContinueOnIgnoredError: false, - givenErrorHandler: func(c echo.Context, err error) error { + givenErrorHandler: func(c *echo.Context, err error) error { return echo.ErrUnauthorized }, whenAuthHeaders: []string{}, // no JWT header @@ -753,7 +782,7 @@ func TestJWTWithConfig_ContinueOnIgnoredError(t *testing.T) { t.Run(tc.name, func(t *testing.T) { e := echo.New() - e.GET("/", func(c echo.Context) error { + e.GET("/", func(c *echo.Context) error { token := c.Get("user").(string) return c.String(http.StatusTeapot, token) }) @@ -761,7 +790,7 @@ func TestJWTWithConfig_ContinueOnIgnoredError(t *testing.T) { mw, err := Config{ ContinueOnIgnoredError: tc.givenContinueOnIgnoredError, TokenLookup: tc.givenTokenLookup, - ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) { + ParseTokenFunc: func(c *echo.Context, auth string) (interface{}, error) { return tc.whenParseReturn, tc.whenParseError }, ErrorHandler: tc.givenErrorHandler, @@ -785,7 +814,7 @@ func TestJWTWithConfig_ContinueOnIgnoredError(t *testing.T) { func TestConfig_TokenLookupFuncs(t *testing.T) { e := echo.New() - e.GET("/", func(c echo.Context) error { + e.GET("/", func(c *echo.Context) error { token := c.Get("user").(*jwt.Token) return c.JSON(http.StatusOK, token.Claims) }) @@ -793,8 +822,8 @@ func TestConfig_TokenLookupFuncs(t *testing.T) { e.Use(WithConfig(Config{ SigningKey: []byte("secret"), TokenLookupFuncs: []middleware.ValuesExtractor{ - func(c echo.Context) ([]string, error) { - return []string{c.Request().Header.Get("X-API-Key")}, nil + func(c *echo.Context) ([]string, middleware.ExtractorSource, error) { + return []string{c.Request().Header.Get("X-API-Key")}, middleware.ExtractorSourceHeader, nil }, }, })) @@ -841,7 +870,7 @@ func TestDataRacesOnParallelExecution(t *testing.T) { } e := echo.New() - e.GET("/", func(c echo.Context) error { + e.GET("/", func(c *echo.Context) error { token := c.Get("user").(*jwt.Token) return c.JSON(http.StatusTeapot, token.Claims) })