diff --git a/internal/middleware/wireguard.go b/internal/middleware/wireguard.go index 03996c29..a9a40130 100644 --- a/internal/middleware/wireguard.go +++ b/internal/middleware/wireguard.go @@ -73,7 +73,14 @@ func WireGuardMiddlewareWithProxy(wireGuardCIDR string, allow bool) func(http.Ha } } -// WireGuardMiddleware enforces policies based on the interface and subnet. +// WireGuardMiddlewareWithInterface enforces policies based on the client's IP and the WireGuard subnet. +// It allows requests if the CLIENT IP is either: +// 1. In the WireGuard subnet (e.g., 100.97.0.0/16), OR +// 2. Arriving on the specified WireGuard interface +// +// This ensures that nodes can access cloud-init either: +// - Through their WireGuard tunnel (client IP in WireGuard subnet) +// - Directly on the server's WireGuard interface func WireGuardMiddlewareWithInterface(wireGuardInterface string, wireGuardCIDR string) func(http.Handler) http.Handler { // Parse the WireGuard CIDR into a *net.IPNet _, wgNet, err := net.ParseCIDR(wireGuardCIDR) @@ -83,43 +90,13 @@ func WireGuardMiddlewareWithInterface(wireGuardInterface string, wireGuardCIDR s return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Retrieve the local address (where the request arrived) - localAddr := r.Context().Value(http.LocalAddrContextKey) - if localAddr == nil { - log.Debug().Msg("Could not determine local address") - http.Error(w, "Could not determine local address", http.StatusForbidden) - return - } - - // Assert net.Addr from the context - addr, ok := localAddr.(net.Addr) - if !ok { - log.Debug().Msg("Invalid address format") - http.Error(w, "Invalid address format", http.StatusForbidden) - return - } - - // Extract the local IP - localIP, _, err := net.SplitHostPort(addr.String()) - if err != nil { - log.Debug().Msg("Could not extract IP from address") - http.Error(w, "Could not extract IP from address", http.StatusForbidden) - return - } - - ip := net.ParseIP(localIP) - if ip == nil { - log.Debug().Msg("Invalid IP Address") - http.Error(w, "Invalid IP Address", http.StatusForbidden) - return - } - + // Extract client IP from request var clientIP string // Check for X-Forwarded-For header xff := r.Header.Get("X-Forwarded-For") if xff != "" { - clientIP = strings.Split(xff, ",")[0] + clientIP = strings.TrimSpace(strings.Split(xff, ",")[0]) } // Check for Forwarded header @@ -137,33 +114,54 @@ func WireGuardMiddlewareWithInterface(wireGuardInterface string, wireGuardCIDR s // Fallback to RemoteAddr if clientIP == "" { + var err error clientIP, _, err = net.SplitHostPort(r.RemoteAddr) if err != nil { + log.Debug().Err(err).Msg("Invalid Remote Address") http.Error(w, "Invalid Remote Address", http.StatusForbidden) return } } - // Check if the IP matches the WireGuard subnet - isInWireGuardSubnet := wgNet.Contains(ip) - var recievedInterface net.Interface - - // Check if the request arrived on the WireGuard interface - isOnWireGuardInterface := false - interfaces, err := net.Interfaces() - if err != nil { - log.Debug().Msg("Could not retrieve network interfaces") - http.Error(w, "Could not retrieve network interfaces", http.StatusInternalServerError) + // Parse client IP + clientIPParsed := net.ParseIP(clientIP) + if clientIPParsed == nil { + log.Debug().Str("clientIP", clientIP).Msg("Invalid client IP Address") + http.Error(w, "Invalid IP Address", http.StatusForbidden) return } - for _, iface := range interfaces { - addrs, _ := iface.Addrs() // Ignoring error on Addrs() as we can still check other interfaces - for _, ifaceAddr := range addrs { - if ipNet, ok := ifaceAddr.(*net.IPNet); ok && ipNet.IP.Equal(ip) { - recievedInterface = iface - if iface.Name == wireGuardInterface { - isOnWireGuardInterface = true - break + + // Check if CLIENT IP is in WireGuard subnet + isInWireGuardSubnet := wgNet.Contains(clientIPParsed) + + // Retrieve the local address (where the request arrived on the server) + var localIP string + var isOnWireGuardInterface bool + var receivedInterface string + + localAddr := r.Context().Value(http.LocalAddrContextKey) + if localAddr != nil { + if addr, ok := localAddr.(net.Addr); ok { + localIP, _, _ = net.SplitHostPort(addr.String()) + if localIPParsed := net.ParseIP(localIP); localIPParsed != nil { + // Check if the request arrived on the WireGuard interface + interfaces, err := net.Interfaces() + if err == nil { + for _, iface := range interfaces { + addrs, _ := iface.Addrs() // Ignoring error on Addrs() as we can still check other interfaces + for _, ifaceAddr := range addrs { + if ipNet, ok := ifaceAddr.(*net.IPNet); ok && ipNet.IP.Equal(localIPParsed) { + receivedInterface = iface.Name + if iface.Name == wireGuardInterface { + isOnWireGuardInterface = true + break + } + } + } + if isOnWireGuardInterface { + break + } + } } } } @@ -172,15 +170,19 @@ func WireGuardMiddlewareWithInterface(wireGuardInterface string, wireGuardCIDR s log.Debug(). Str("localIP", localIP). Str("clientIP", clientIP). - Str("interface", recievedInterface.Name). + Str("interface", receivedInterface). Bool("isInWireGuardSubnet", isInWireGuardSubnet). Bool("isOnWireGuardInterface", isOnWireGuardInterface). Msg("WireGuard policy check") - // Enforce the policy: deny if neither condition is true + // Enforce the policy: allow if CLIENT IP is in WireGuard subnet OR request arrived on WireGuard interface if !isInWireGuardSubnet && !isOnWireGuardInterface { - log.Debug().Msgf("Access denied: IP %s not in WireGuard subnet or interface", localIP) - http.Error(w, fmt.Sprintf("Access denied: IP %s not in WireGuard subnet or interface", localIP), http.StatusForbidden) + log.Debug(). + Str("clientIP", clientIP). + Str("localIP", localIP). + Str("interface", receivedInterface). + Msgf("Access denied: client IP %s not in WireGuard subnet and request not on WireGuard interface", clientIP) + http.Error(w, fmt.Sprintf("Access denied: client IP %s not in WireGuard subnet or on WireGuard interface", clientIP), http.StatusForbidden) return } diff --git a/internal/middleware/wireguard_test.go b/internal/middleware/wireguard_test.go new file mode 100644 index 00000000..755e3070 --- /dev/null +++ b/internal/middleware/wireguard_test.go @@ -0,0 +1,390 @@ +package middleware + +import ( + "context" + "net" + "net/http" + "net/http/httptest" + "testing" +) + +// mockAddr implements net.Addr interface for testing +type mockAddr struct { + network string + address string +} + +func (m mockAddr) Network() string { + return m.network +} + +func (m mockAddr) String() string { + return m.address +} + +// TestWireGuardMiddlewareWithProxy tests the proxy-based WireGuard middleware +func TestWireGuardMiddlewareWithProxy(t *testing.T) { + testCases := []struct { + name string + wireGuardCIDR string + allow bool + clientIP string + xff string + forwarded string + expectedStatus int + expectedBody string + }{ + { + name: "Allow client in WireGuard subnet", + wireGuardCIDR: "100.97.0.0/16", + allow: true, + clientIP: "100.97.0.5", + expectedStatus: http.StatusOK, + expectedBody: "OK", + }, + { + name: "Deny client not in WireGuard subnet when allow=true", + wireGuardCIDR: "100.97.0.0/16", + allow: true, + clientIP: "192.168.1.10", + expectedStatus: http.StatusForbidden, + expectedBody: "Access denied: Not in WireGuard subnet\n", + }, + { + name: "Allow client not in WireGuard subnet when allow=false", + wireGuardCIDR: "100.97.0.0/16", + allow: false, + clientIP: "192.168.1.10", + expectedStatus: http.StatusOK, + expectedBody: "OK", + }, + { + name: "Deny client in WireGuard subnet when allow=false", + wireGuardCIDR: "100.97.0.0/16", + allow: false, + clientIP: "100.97.0.5", + expectedStatus: http.StatusForbidden, + expectedBody: "Access denied: WireGuard traffic not allowed\n", + }, + { + name: "Use X-Forwarded-For header", + wireGuardCIDR: "100.97.0.0/16", + allow: true, + clientIP: "192.168.1.10", // RemoteAddr (should be ignored) + xff: "100.97.0.20", // X-Forwarded-For (should be used) + expectedStatus: http.StatusOK, + expectedBody: "OK", + }, + { + name: "Use Forwarded header", + wireGuardCIDR: "100.97.0.0/16", + allow: true, + clientIP: "192.168.1.10", // RemoteAddr (should be ignored) + forwarded: "for=100.97.0.30", // Forwarded (should be used) + expectedStatus: http.StatusOK, + expectedBody: "OK", + }, + { + name: "X-Forwarded-For takes precedence over Forwarded", + wireGuardCIDR: "100.97.0.0/16", + allow: true, + clientIP: "192.168.1.10", + xff: "100.97.0.40", + forwarded: "for=192.168.1.20", + expectedStatus: http.StatusOK, + expectedBody: "OK", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create a test handler that will be wrapped by the middleware + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("OK")) + }) + + // Apply middleware + middleware := WireGuardMiddlewareWithProxy(tc.wireGuardCIDR, tc.allow) + wrappedHandler := middleware(handler) + + // Create test request + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = tc.clientIP + ":12345" + if tc.xff != "" { + req.Header.Set("X-Forwarded-For", tc.xff) + } + if tc.forwarded != "" { + req.Header.Set("Forwarded", tc.forwarded) + } + + // Record response + rr := httptest.NewRecorder() + wrappedHandler.ServeHTTP(rr, req) + + // Check status code + if rr.Code != tc.expectedStatus { + t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code) + } + + // Check body + if rr.Body.String() != tc.expectedBody { + t.Errorf("Expected body %q, got %q", tc.expectedBody, rr.Body.String()) + } + }) + } +} + +// TestWireGuardMiddlewareWithInterface tests the interface-based WireGuard middleware +func TestWireGuardMiddlewareWithInterface(t *testing.T) { + testCases := []struct { + name string + wireGuardCIDR string + wireGuardInterface string + clientIP string + xff string + localAddr net.Addr + expectedStatus int + description string + }{ + { + name: "Allow client in WireGuard subnet", + wireGuardCIDR: "100.97.0.0/16", + wireGuardInterface: "wg0", + clientIP: "100.97.0.5", + localAddr: mockAddr{"tcp", "192.168.1.100:27777"}, + expectedStatus: http.StatusOK, + description: "Client with WireGuard IP should be allowed regardless of interface", + }, + { + name: "Deny client not in WireGuard subnet on non-WireGuard interface", + wireGuardCIDR: "100.97.0.0/16", + wireGuardInterface: "wg0", + clientIP: "192.168.1.10", + localAddr: mockAddr{"tcp", "192.168.1.100:27777"}, + expectedStatus: http.StatusForbidden, + description: "Client without WireGuard IP on regular interface should be denied", + }, + { + name: "Allow client with X-Forwarded-For in WireGuard subnet", + wireGuardCIDR: "100.97.0.0/16", + wireGuardInterface: "wg0", + clientIP: "192.168.1.10", // RemoteAddr + xff: "100.97.0.20", // X-Forwarded-For + localAddr: mockAddr{"tcp", "192.168.1.100:27777"}, + expectedStatus: http.StatusOK, + description: "Should use X-Forwarded-For when present", + }, + { + name: "Allow multiple IPs in X-Forwarded-For (use first)", + wireGuardCIDR: "100.97.0.0/16", + wireGuardInterface: "wg0", + clientIP: "192.168.1.10", + xff: "100.97.0.30, 10.0.0.1, 172.16.0.1", + localAddr: mockAddr{"tcp", "192.168.1.100:27777"}, + expectedStatus: http.StatusOK, + description: "Should use first IP in X-Forwarded-For chain", + }, + { + name: "Deny invalid client IP", + wireGuardCIDR: "100.97.0.0/16", + wireGuardInterface: "wg0", + clientIP: "invalid-ip", + localAddr: mockAddr{"tcp", "192.168.1.100:27777"}, + expectedStatus: http.StatusForbidden, + description: "Invalid IP should be rejected", + }, + { + name: "Allow client at edge of subnet", + wireGuardCIDR: "100.97.0.0/16", + wireGuardInterface: "wg0", + clientIP: "100.97.255.255", + localAddr: mockAddr{"tcp", "192.168.1.100:27777"}, + expectedStatus: http.StatusOK, + description: "IP at edge of subnet should be allowed", + }, + { + name: "Deny client just outside subnet", + wireGuardCIDR: "100.97.0.0/16", + wireGuardInterface: "wg0", + clientIP: "100.98.0.1", + localAddr: mockAddr{"tcp", "192.168.1.100:27777"}, + expectedStatus: http.StatusForbidden, + description: "IP just outside subnet should be denied", + }, + { + name: "Allow client with /24 subnet", + wireGuardCIDR: "10.89.0.0/24", + wireGuardInterface: "wg0", + clientIP: "10.89.0.50", + localAddr: mockAddr{"tcp", "192.168.1.100:27777"}, + expectedStatus: http.StatusOK, + description: "Should work with smaller subnets", + }, + { + name: "Handle no local address gracefully", + wireGuardCIDR: "100.97.0.0/16", + wireGuardInterface: "wg0", + clientIP: "100.97.0.5", + localAddr: nil, + expectedStatus: http.StatusOK, + description: "Should still allow based on client IP when local addr is missing", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create a test handler + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("OK")) + }) + + // Apply middleware + middleware := WireGuardMiddlewareWithInterface(tc.wireGuardInterface, tc.wireGuardCIDR) + wrappedHandler := middleware(handler) + + // Create test request + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = tc.clientIP + ":12345" + if tc.xff != "" { + req.Header.Set("X-Forwarded-For", tc.xff) + } + + // Add local address to context if provided + if tc.localAddr != nil { + ctx := context.WithValue(req.Context(), http.LocalAddrContextKey, tc.localAddr) + req = req.WithContext(ctx) + } + + // Record response + rr := httptest.NewRecorder() + wrappedHandler.ServeHTTP(rr, req) + + // Check status code + if rr.Code != tc.expectedStatus { + t.Errorf("%s: Expected status %d, got %d", tc.description, tc.expectedStatus, rr.Code) + } + }) + } +} + +// TestWireGuardMiddlewareWithProxy_InvalidCIDR tests panic on invalid CIDR +func TestWireGuardMiddlewareWithProxy_InvalidCIDR(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("Expected panic with invalid CIDR, but didn't panic") + } + }() + + // This should panic + _ = WireGuardMiddlewareWithProxy("invalid-cidr", true) +} + +// TestWireGuardMiddlewareWithInterface_InvalidCIDR tests panic on invalid CIDR +func TestWireGuardMiddlewareWithInterface_InvalidCIDR(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("Expected panic with invalid CIDR, but didn't panic") + } + }() + + // This should panic + _ = WireGuardMiddlewareWithInterface("wg0", "invalid-cidr") +} + +// TestWireGuardMiddleware_RealWorldScenario simulates the bug scenario from v1.4.1 +func TestWireGuardMiddleware_RealWorldScenario(t *testing.T) { + // Scenario: 200 nodes trying to get cloud-init data + // - WireGuard subnet: 100.97.0.0/16 + // - Nodes have regular IPs: 192.168.1.0/24 + // - Some nodes have established WireGuard tunnels and got IPs in 100.97.0.0/16 + // - Server listens on 192.168.1.100:27777 + + middleware := WireGuardMiddlewareWithInterface("wg0", "100.97.0.0/16") + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("cloud-init-data")) + }) + wrappedHandler := middleware(handler) + + testCases := []struct { + name string + nodeIP string + hasWireGuard bool + wgIP string + expectedStatus int + description string + }{ + { + name: "Node with WireGuard tunnel", + nodeIP: "192.168.1.10", + hasWireGuard: true, + wgIP: "100.97.0.5", + expectedStatus: http.StatusOK, + description: "Node that established WireGuard tunnel should be allowed", + }, + { + name: "Node without WireGuard tunnel", + nodeIP: "192.168.1.11", + hasWireGuard: false, + expectedStatus: http.StatusForbidden, + description: "Node without WireGuard tunnel should be denied", + }, + { + name: "Node with WireGuard at edge of subnet", + nodeIP: "192.168.1.12", + hasWireGuard: true, + wgIP: "100.97.255.254", + expectedStatus: http.StatusOK, + description: "Node with WireGuard IP at edge should be allowed", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/user-data", nil) + + // Simulate the request coming from the node + if tc.hasWireGuard { + // Node uses WireGuard IP as source + req.RemoteAddr = tc.wgIP + ":54321" + } else { + // Node uses regular IP as source + req.RemoteAddr = tc.nodeIP + ":54321" + } + + // Server received request on its eth0 interface + ctx := context.WithValue(req.Context(), http.LocalAddrContextKey, mockAddr{"tcp", "192.168.1.100:27777"}) + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + wrappedHandler.ServeHTTP(rr, req) + + if rr.Code != tc.expectedStatus { + t.Errorf("%s: Expected status %d, got %d. Body: %s", + tc.description, tc.expectedStatus, rr.Code, rr.Body.String()) + } + }) + } +} + +// BenchmarkWireGuardMiddlewareWithInterface benchmarks the middleware performance +func BenchmarkWireGuardMiddlewareWithInterface(b *testing.B) { + middleware := WireGuardMiddlewareWithInterface("wg0", "100.97.0.0/16") + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + wrappedHandler := middleware(handler) + + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "100.97.0.5:12345" + ctx := context.WithValue(req.Context(), http.LocalAddrContextKey, mockAddr{"tcp", "192.168.1.100:27777"}) + req = req.WithContext(ctx) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + rr := httptest.NewRecorder() + wrappedHandler.ServeHTTP(rr, req) + } +}