Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 58 additions & 4 deletions middleware/rate_limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,33 @@ import (
"errors"
"math"
"net/http"
"strconv"
"sync"
"time"

"github.com/labstack/echo/v5"
"golang.org/x/time/rate"
)

// Rate limit response headers set by stores that implement RateLimiterStoreContext.
const (
HeaderXRateLimitLimit = "X-RateLimit-Limit"
HeaderXRateLimitRemaining = "X-RateLimit-Remaining"
)

// RateLimiterStore is the interface to be implemented by custom stores.
type RateLimiterStore interface {
Allow(identifier string) (bool, error)
}

// RateLimiterStoreContext is an optional interface a RateLimiterStore may implement.
// When the configured store implements it, the rate limiter calls AllowContext
// (with the request context) instead of Allow, allowing the store to set response
// headers such as Retry-After or X-RateLimit-* on the allow/deny decision.
type RateLimiterStoreContext interface {
AllowContext(c *echo.Context, identifier string) (bool, error)
}

// RateLimiterConfig defines the configuration for the rate limiter
type RateLimiterConfig struct {
Skipper Skipper
Expand Down Expand Up @@ -136,7 +151,14 @@ func (config RateLimiterConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
return config.ErrorHandler(c, err)
}

if allow, allowErr := config.Store.Allow(identifier); !allow {
var allow bool
var allowErr error
if sc, ok := config.Store.(RateLimiterStoreContext); ok {
allow, allowErr = sc.AllowContext(c, identifier)
} else {
allow, allowErr = config.Store.Allow(identifier)
}
if !allow {
return config.DenyHandler(c, identifier, allowErr)
}
return next(c)
Expand Down Expand Up @@ -232,7 +254,22 @@ var DefaultRateLimiterMemoryStoreConfig = RateLimiterMemoryStoreConfig{

// Allow implements RateLimiterStore.Allow
func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) {
_, allowed := store.allow(identifier)
return allowed, nil
}

// AllowContext implements RateLimiterStoreContext: it makes the allow/deny decision
// and sets the X-RateLimit-* (and Retry-After when denied) response headers.
func (store *RateLimiterMemoryStore) AllowContext(c *echo.Context, identifier string) (bool, error) {
limiter, allowed := store.allow(identifier)
store.setRateLimitHeaders(c, limiter, allowed)
return allowed, nil
}

func (store *RateLimiterMemoryStore) allow(identifier string) (*rate.Limiter, bool) {
store.mutex.Lock()
defer store.mutex.Unlock()

limiter, exists := store.visitors[identifier]
if !exists {
limiter = new(Visitor)
Expand All @@ -244,9 +281,26 @@ func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) {
if now.Sub(store.lastCleanup) > store.expiresIn {
store.cleanupStaleVisitors(now)
}
allowed := limiter.AllowN(now, 1)
store.mutex.Unlock()
return allowed, nil
return limiter.Limiter, limiter.AllowN(now, 1)
}

func (store *RateLimiterMemoryStore) setRateLimitHeaders(c *echo.Context, limiter *rate.Limiter, allowed bool) {
header := c.Response().Header()
header.Set(HeaderXRateLimitLimit, strconv.Itoa(store.burst))

remaining := int(math.Floor(limiter.Tokens()))
if remaining < 0 {
remaining = 0
}
header.Set(HeaderXRateLimitRemaining, strconv.Itoa(remaining))

if !allowed {
reservation := limiter.ReserveN(store.timeNow(), 1)
if delay := reservation.Delay(); delay > 0 {
header.Set(echo.HeaderRetryAfter, strconv.Itoa(int(math.Ceil(delay.Seconds()))))
}
reservation.Cancel()
}
}

/*
Expand Down
89 changes: 89 additions & 0 deletions middleware/rate_limiter_context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors

package middleware

import (
"net/http"
"net/http/httptest"
"strconv"
"testing"

"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)

// ctxAwareStore implements both Allow and the optional AllowContext. AllowContext
// gives the store the request context so it can set response headers (e.g.
// Retry-After / X-RateLimit-*) — see #2961.
type ctxAwareStore struct {
allowCalled bool
ctxAllowCalled bool
allow bool
}

func (s *ctxAwareStore) Allow(identifier string) (bool, error) {
s.allowCalled = true
return s.allow, nil
}

func (s *ctxAwareStore) AllowContext(c *echo.Context, identifier string) (bool, error) {
s.ctxAllowCalled = true
c.Response().Header().Set("Retry-After", "42")
return s.allow, nil
}

// When the store implements AllowContext, the middleware must call it instead of
// Allow, so the store can set rate-limit headers on the response.
func TestRateLimiter_storeAllowContextIsPreferred(t *testing.T) {
e := echo.New()
store := &ctxAwareStore{allow: true}
mw := RateLimiterWithConfig(RateLimiterConfig{
Store: store,
IdentifierExtractor: func(c *echo.Context) (string, error) { return "id", nil },
})
handler := mw(func(c *echo.Context) error { return c.String(http.StatusOK, "ok") })

req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)

assert.NoError(t, handler(c))
assert.True(t, store.ctxAllowCalled, "AllowContext should be called when implemented")
assert.False(t, store.allowCalled, "Allow should not be called when AllowContext is implemented")
assert.Equal(t, "42", rec.Header().Get("Retry-After"), "store should be able to set headers via the context")
}

// The built-in memory store implements AllowContext, so it sets X-RateLimit-Limit /
// X-RateLimit-Remaining on every request and Retry-After when the limit is hit (#2961).
func TestRateLimiterMemoryStore_AllowContextSetsHeaders(t *testing.T) {
store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
e := echo.New()
e.GET("/", func(c *echo.Context) error { return c.String(http.StatusOK, "ok") },
RateLimiterWithConfig(RateLimiterConfig{
Store: store,
IdentifierExtractor: func(c *echo.Context) (string, error) { return "id", nil },
}))

do := func() *httptest.ResponseRecorder {
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
return rec
}

// Burst of 3: each allowed request advertises the limit and decreasing remaining.
for i := 0; i < 3; i++ {
rec := do()
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "3", rec.Header().Get(HeaderXRateLimitLimit))
assert.Equal(t, strconv.Itoa(2-i), rec.Header().Get(HeaderXRateLimitRemaining))
assert.Empty(t, rec.Header().Get(echo.HeaderRetryAfter))
}

// 4th request is denied: 429, remaining 0, and a Retry-After hint.
rec := do()
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
assert.Equal(t, "0", rec.Header().Get(HeaderXRateLimitRemaining))
assert.NotEmpty(t, rec.Header().Get(echo.HeaderRetryAfter))
}
Loading