Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions middleware/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,14 @@ func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if req.Header.Get(echo.HeaderXForwardedProto) == "" {
req.Header.Set(echo.HeaderXForwardedProto, c.Scheme())
}
if c.IsWebSocket() && req.Header.Get(echo.HeaderXForwardedFor) == "" { // For HTTP, it is automatically set by Go HTTP reverse proxy.
req.Header.Set(echo.HeaderXForwardedFor, c.RealIP())
if c.IsWebSocket() { // For HTTP, this is set by Go HTTP reverse proxy.
// Append, not set, to preserve the incoming chain from upstream proxies.
prior := req.Header[echo.HeaderXForwardedFor]
if len(prior) > 0 {
req.Header.Set(echo.HeaderXForwardedFor, strings.Join(prior, ", ")+", "+c.RealIP())
} else {
req.Header.Set(echo.HeaderXForwardedFor, c.RealIP())
}
}

retries := config.RetryCount
Expand Down
106 changes: 106 additions & 0 deletions middleware/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"net/http/httptest"
"net/url"
"regexp"
"strings"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -1164,3 +1165,108 @@ func TestProxyWithConfigWebSocketTLS2NonTLS(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, sendMsg, recvMsg)
}

// TestProxyWebSocketXForwardedFor verifies that for WebSocket Upgrade requests,
// the proxy middleware appends c.RealIP() to any existing X-Forwarded-For chain,
// mirroring net/http/httputil.(*ProxyRequest).SetXForwarded used by the HTTP path.
//
// Regression guard for the previous "set only if empty" behavior, which dropped
// the proxy's own peer IP from the chain whenever upstream proxies had already
// added entries.
func TestProxyWebSocketXForwardedFor(t *testing.T) {
tests := []struct {
name string
incomingXFF []string // nil = no incoming X-Forwarded-For header at all
wantPrefix string // expected join of entries preceding the appended proxy RealIP
}{
{
name: "no incoming XFF, only proxy RealIP is set",
incomingXFF: nil,
wantPrefix: "",
},
{
name: "single-line single-entry XFF is preserved with proxy RealIP appended",
incomingXFF: []string{"203.0.113.1"},
wantPrefix: "203.0.113.1",
},
{
name: "single-line comma-separated XFF is preserved with proxy RealIP appended",
incomingXFF: []string{"203.0.113.1, 10.0.0.5"},
wantPrefix: "203.0.113.1, 10.0.0.5",
},
{
name: "multi-line XFF (multiple header occurrences) is joined with proxy RealIP appended",
incomingXFF: []string{"203.0.113.1", "10.0.0.5"},
wantPrefix: "203.0.113.1, 10.0.0.5",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Buffered so the upstream handler never blocks before the client reads.
headerCh := make(chan http.Header, 1)

upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
wsHandler := func(conn *websocket.Conn) {
headerCh <- conn.Request().Header.Clone()
defer conn.Close()
var msg string
if err := websocket.Message.Receive(conn, &msg); err == nil {
_ = websocket.Message.Send(conn, msg)
}
}
websocket.Server{Handler: wsHandler}.ServeHTTP(w, r)
}))
defer upstream.Close()

tgtURL, _ := url.Parse(upstream.URL)
e := echo.New()
e.Use(ProxyWithConfig(ProxyConfig{Balancer: NewRandomBalancer([]*ProxyTarget{{URL: tgtURL}})}))
proxySrv := httptest.NewServer(e)
defer proxySrv.Close()

proxyWSURL, _ := url.Parse(proxySrv.URL)
proxyWSURL.Scheme = "ws"

origin, _ := url.Parse(proxySrv.URL)
cfg := &websocket.Config{
Location: proxyWSURL,
Origin: origin,
Version: websocket.ProtocolVersionHybi13,
Header: http.Header{},
}
for _, v := range tt.incomingXFF {
cfg.Header.Add(echo.HeaderXForwardedFor, v)
}

wsConn, err := websocket.DialConfig(cfg)
assert.NoError(t, err)
defer wsConn.Close()

assert.NoError(t, websocket.Message.Send(wsConn, "ping"))
var got string
assert.NoError(t, websocket.Message.Receive(wsConn, &got))

// The handler sends to headerCh before echoing, so it arrives before Receive returns.
captured := <-headerCh
xff := captured.Get(echo.HeaderXForwardedFor)

// The middleware uses Header.Set, so the upstream sees exactly one
// X-Forwarded-For header line. Split it back into entries.
entries := strings.Split(xff, ", ")
assert.NotEmpty(t, entries, "X-Forwarded-For must be set by the proxy middleware")

// The tail entry is the proxy's c.RealIP(). When the test client dials
// via httptest.NewServer the proxy sees 127.0.0.1.
tail := entries[len(entries)-1]
assert.Equal(t, "127.0.0.1", tail,
"proxy RealIP must be appended at the tail of X-Forwarded-For")

// The remaining entries must equal the prior chain, preserving order
// and joining multi-line headers with ", ".
gotPrefix := strings.Join(entries[:len(entries)-1], ", ")
assert.Equal(t, tt.wantPrefix, gotPrefix,
"prior X-Forwarded-For entries must be preserved before the appended RealIP")
})
}
}