Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 57 additions & 55 deletions internal/middleware/wireguard.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,14 @@ func WireGuardMiddlewareWithProxy(wireGuardCIDR string, allow bool) func(http.Ha
}
}

// WireGuardMiddleware enforces policies based on the interface and subnet.
// WireGuardMiddlewareWithInterface enforces policies based on the client's IP and the WireGuard subnet.
// It allows requests if the CLIENT IP is either:
// 1. In the WireGuard subnet (e.g., 100.97.0.0/16), OR
// 2. Arriving on the specified WireGuard interface
//
// This ensures that nodes can access cloud-init either:
// - Through their WireGuard tunnel (client IP in WireGuard subnet)
// - Directly on the server's WireGuard interface
func WireGuardMiddlewareWithInterface(wireGuardInterface string, wireGuardCIDR string) func(http.Handler) http.Handler {
// Parse the WireGuard CIDR into a *net.IPNet
_, wgNet, err := net.ParseCIDR(wireGuardCIDR)
Expand All @@ -83,43 +90,13 @@ func WireGuardMiddlewareWithInterface(wireGuardInterface string, wireGuardCIDR s

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Retrieve the local address (where the request arrived)
localAddr := r.Context().Value(http.LocalAddrContextKey)
if localAddr == nil {
log.Debug().Msg("Could not determine local address")
http.Error(w, "Could not determine local address", http.StatusForbidden)
return
}

// Assert net.Addr from the context
addr, ok := localAddr.(net.Addr)
if !ok {
log.Debug().Msg("Invalid address format")
http.Error(w, "Invalid address format", http.StatusForbidden)
return
}

// Extract the local IP
localIP, _, err := net.SplitHostPort(addr.String())
if err != nil {
log.Debug().Msg("Could not extract IP from address")
http.Error(w, "Could not extract IP from address", http.StatusForbidden)
return
}

ip := net.ParseIP(localIP)
if ip == nil {
log.Debug().Msg("Invalid IP Address")
http.Error(w, "Invalid IP Address", http.StatusForbidden)
return
}

// Extract client IP from request
var clientIP string

// Check for X-Forwarded-For header
xff := r.Header.Get("X-Forwarded-For")
if xff != "" {
clientIP = strings.Split(xff, ",")[0]
clientIP = strings.TrimSpace(strings.Split(xff, ",")[0])
}

// Check for Forwarded header
Expand All @@ -137,33 +114,54 @@ func WireGuardMiddlewareWithInterface(wireGuardInterface string, wireGuardCIDR s

// Fallback to RemoteAddr
if clientIP == "" {
var err error
clientIP, _, err = net.SplitHostPort(r.RemoteAddr)
if err != nil {
log.Debug().Err(err).Msg("Invalid Remote Address")
http.Error(w, "Invalid Remote Address", http.StatusForbidden)
return
}
}

// Check if the IP matches the WireGuard subnet
isInWireGuardSubnet := wgNet.Contains(ip)
var recievedInterface net.Interface

// Check if the request arrived on the WireGuard interface
isOnWireGuardInterface := false
interfaces, err := net.Interfaces()
if err != nil {
log.Debug().Msg("Could not retrieve network interfaces")
http.Error(w, "Could not retrieve network interfaces", http.StatusInternalServerError)
// Parse client IP
clientIPParsed := net.ParseIP(clientIP)
if clientIPParsed == nil {
log.Debug().Str("clientIP", clientIP).Msg("Invalid client IP Address")
http.Error(w, "Invalid IP Address", http.StatusForbidden)
return
}
for _, iface := range interfaces {
addrs, _ := iface.Addrs() // Ignoring error on Addrs() as we can still check other interfaces
for _, ifaceAddr := range addrs {
if ipNet, ok := ifaceAddr.(*net.IPNet); ok && ipNet.IP.Equal(ip) {
recievedInterface = iface
if iface.Name == wireGuardInterface {
isOnWireGuardInterface = true
break

// Check if CLIENT IP is in WireGuard subnet
isInWireGuardSubnet := wgNet.Contains(clientIPParsed)

// Retrieve the local address (where the request arrived on the server)
var localIP string
var isOnWireGuardInterface bool
var receivedInterface string

localAddr := r.Context().Value(http.LocalAddrContextKey)
if localAddr != nil {
if addr, ok := localAddr.(net.Addr); ok {
localIP, _, _ = net.SplitHostPort(addr.String())
if localIPParsed := net.ParseIP(localIP); localIPParsed != nil {
// Check if the request arrived on the WireGuard interface
interfaces, err := net.Interfaces()
if err == nil {
for _, iface := range interfaces {
addrs, _ := iface.Addrs() // Ignoring error on Addrs() as we can still check other interfaces
for _, ifaceAddr := range addrs {
if ipNet, ok := ifaceAddr.(*net.IPNet); ok && ipNet.IP.Equal(localIPParsed) {
receivedInterface = iface.Name
if iface.Name == wireGuardInterface {
isOnWireGuardInterface = true
break
}
}
}
if isOnWireGuardInterface {
break
}
}
}
}
}
Expand All @@ -172,15 +170,19 @@ func WireGuardMiddlewareWithInterface(wireGuardInterface string, wireGuardCIDR s
log.Debug().
Str("localIP", localIP).
Str("clientIP", clientIP).
Str("interface", recievedInterface.Name).
Str("interface", receivedInterface).
Bool("isInWireGuardSubnet", isInWireGuardSubnet).
Bool("isOnWireGuardInterface", isOnWireGuardInterface).
Msg("WireGuard policy check")

// Enforce the policy: deny if neither condition is true
// Enforce the policy: allow if CLIENT IP is in WireGuard subnet OR request arrived on WireGuard interface
if !isInWireGuardSubnet && !isOnWireGuardInterface {
log.Debug().Msgf("Access denied: IP %s not in WireGuard subnet or interface", localIP)
http.Error(w, fmt.Sprintf("Access denied: IP %s not in WireGuard subnet or interface", localIP), http.StatusForbidden)
log.Debug().
Str("clientIP", clientIP).
Str("localIP", localIP).
Str("interface", receivedInterface).
Msgf("Access denied: client IP %s not in WireGuard subnet and request not on WireGuard interface", clientIP)
http.Error(w, fmt.Sprintf("Access denied: client IP %s not in WireGuard subnet or on WireGuard interface", clientIP), http.StatusForbidden)
return
}

Expand Down
Loading
Loading