diff --git a/middleware/proxy.go b/middleware/proxy.go index a40d58130..497aefea4 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -355,8 +355,14 @@ func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if req.Header.Get(echo.HeaderXForwardedProto) == "" { req.Header.Set(echo.HeaderXForwardedProto, c.Scheme()) } - if c.IsWebSocket() && req.Header.Get(echo.HeaderXForwardedFor) == "" { // For HTTP, it is automatically set by Go HTTP reverse proxy. - req.Header.Set(echo.HeaderXForwardedFor, c.RealIP()) + if c.IsWebSocket() { // For HTTP, this is set by Go HTTP reverse proxy. + // Append, not set, to preserve the incoming chain from upstream proxies. + prior := req.Header[echo.HeaderXForwardedFor] + if len(prior) > 0 { + req.Header.Set(echo.HeaderXForwardedFor, strings.Join(prior, ", ")+", "+c.RealIP()) + } else { + req.Header.Set(echo.HeaderXForwardedFor, c.RealIP()) + } } retries := config.RetryCount diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index 5494b23ba..5053f7945 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -15,6 +15,7 @@ import ( "net/http/httptest" "net/url" "regexp" + "strings" "sync" "testing" "time" @@ -1164,3 +1165,108 @@ func TestProxyWithConfigWebSocketTLS2NonTLS(t *testing.T) { assert.NoError(t, err) assert.Equal(t, sendMsg, recvMsg) } + +// TestProxyWebSocketXForwardedFor verifies that for WebSocket Upgrade requests, +// the proxy middleware appends c.RealIP() to any existing X-Forwarded-For chain, +// mirroring net/http/httputil.(*ProxyRequest).SetXForwarded used by the HTTP path. +// +// Regression guard for the previous "set only if empty" behavior, which dropped +// the proxy's own peer IP from the chain whenever upstream proxies had already +// added entries. +func TestProxyWebSocketXForwardedFor(t *testing.T) { + tests := []struct { + name string + incomingXFF []string // nil = no incoming X-Forwarded-For header at all + wantPrefix string // expected join of entries preceding the appended proxy RealIP + }{ + { + name: "no incoming XFF, only proxy RealIP is set", + incomingXFF: nil, + wantPrefix: "", + }, + { + name: "single-line single-entry XFF is preserved with proxy RealIP appended", + incomingXFF: []string{"203.0.113.1"}, + wantPrefix: "203.0.113.1", + }, + { + name: "single-line comma-separated XFF is preserved with proxy RealIP appended", + incomingXFF: []string{"203.0.113.1, 10.0.0.5"}, + wantPrefix: "203.0.113.1, 10.0.0.5", + }, + { + name: "multi-line XFF (multiple header occurrences) is joined with proxy RealIP appended", + incomingXFF: []string{"203.0.113.1", "10.0.0.5"}, + wantPrefix: "203.0.113.1, 10.0.0.5", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Buffered so the upstream handler never blocks before the client reads. + headerCh := make(chan http.Header, 1) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wsHandler := func(conn *websocket.Conn) { + headerCh <- conn.Request().Header.Clone() + defer conn.Close() + var msg string + if err := websocket.Message.Receive(conn, &msg); err == nil { + _ = websocket.Message.Send(conn, msg) + } + } + websocket.Server{Handler: wsHandler}.ServeHTTP(w, r) + })) + defer upstream.Close() + + tgtURL, _ := url.Parse(upstream.URL) + e := echo.New() + e.Use(ProxyWithConfig(ProxyConfig{Balancer: NewRandomBalancer([]*ProxyTarget{{URL: tgtURL}})})) + proxySrv := httptest.NewServer(e) + defer proxySrv.Close() + + proxyWSURL, _ := url.Parse(proxySrv.URL) + proxyWSURL.Scheme = "ws" + + origin, _ := url.Parse(proxySrv.URL) + cfg := &websocket.Config{ + Location: proxyWSURL, + Origin: origin, + Version: websocket.ProtocolVersionHybi13, + Header: http.Header{}, + } + for _, v := range tt.incomingXFF { + cfg.Header.Add(echo.HeaderXForwardedFor, v) + } + + wsConn, err := websocket.DialConfig(cfg) + assert.NoError(t, err) + defer wsConn.Close() + + assert.NoError(t, websocket.Message.Send(wsConn, "ping")) + var got string + assert.NoError(t, websocket.Message.Receive(wsConn, &got)) + + // The handler sends to headerCh before echoing, so it arrives before Receive returns. + captured := <-headerCh + xff := captured.Get(echo.HeaderXForwardedFor) + + // The middleware uses Header.Set, so the upstream sees exactly one + // X-Forwarded-For header line. Split it back into entries. + entries := strings.Split(xff, ", ") + assert.NotEmpty(t, entries, "X-Forwarded-For must be set by the proxy middleware") + + // The tail entry is the proxy's c.RealIP(). When the test client dials + // via httptest.NewServer the proxy sees 127.0.0.1. + tail := entries[len(entries)-1] + assert.Equal(t, "127.0.0.1", tail, + "proxy RealIP must be appended at the tail of X-Forwarded-For") + + // The remaining entries must equal the prior chain, preserving order + // and joining multi-line headers with ", ". + gotPrefix := strings.Join(entries[:len(entries)-1], ", ") + assert.Equal(t, tt.wantPrefix, gotPrefix, + "prior X-Forwarded-For entries must be preserved before the appended RealIP") + }) + } +}