diff --git a/cert/cert_v2.go b/cert/cert_v2.go index 4648c496..09f7bd79 100644 --- a/cert/cert_v2.go +++ b/cert/cert_v2.go @@ -439,11 +439,7 @@ func (c *certificateV2) validate() error { if !hasV6Networks { return NewErrInvalidCertificateProperties("IPv6 unsafe networks require an IPv6 address assignment: %s", network) } - } else if network.Addr().Is4() { - if !hasV4Networks { - return NewErrInvalidCertificateProperties("IPv4 unsafe networks require an IPv4 address assignment: %s", network) - } - } + } // as long as we have any IP address, IPv4 UnsafeNetworks are allowed } } diff --git a/e2e/snat_test.go b/e2e/snat_test.go new file mode 100644 index 00000000..a0e0af96 --- /dev/null +++ b/e2e/snat_test.go @@ -0,0 +1,400 @@ +//go:build e2e_testing +// +build e2e_testing + +package e2e + +import ( + "encoding/binary" + "net/netip" + "testing" + "time" + + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/cert_test" + "github.com/slackhq/nebula/e2e/router" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// parseIPv4UDPPacket extracts source/dest IPs, ports, and payload from an IPv4 UDP packet. +func parseIPv4UDPPacket(t testing.TB, pkt []byte) (srcIP, dstIP netip.Addr, srcPort, dstPort uint16, payload []byte) { + t.Helper() + require.True(t, len(pkt) >= 28, "packet too short for IPv4+UDP header") + require.Equal(t, byte(0x45), pkt[0]&0xF0|pkt[0]&0x0F, "not a simple IPv4 packet (IHL!=5)") + + srcIP, _ = netip.AddrFromSlice(pkt[12:16]) + dstIP, _ = netip.AddrFromSlice(pkt[16:20]) + + ihl := int(pkt[0]&0x0F) * 4 + require.True(t, len(pkt) >= ihl+8, "packet too short for UDP header") + srcPort = binary.BigEndian.Uint16(pkt[ihl : ihl+2]) + dstPort = binary.BigEndian.Uint16(pkt[ihl+2 : ihl+4]) + udpLen := binary.BigEndian.Uint16(pkt[ihl+4 : ihl+6]) + payload = pkt[ihl+8 : ihl+int(udpLen)] + return +} + +func TestSNAT_IPv6OnlyPeer_IPv4UnsafeTraffic(t *testing.T) { + // Scenario: Two IPv6-only VPN nodes. The "router" node has unsafe networks + // (192.168.0.0/16) in its cert and a configured SNAT address. The "sender" + // node has an unsafe route for 192.168.0.0/16 via the router. + // + // When sender injects an IPv4 packet destined for the unsafe network, it + // gets tunneled to the router. The router's firewall detects this is IPv4 + // from an IPv6-only peer and applies SNAT, rewriting the source IP to the + // SNAT address before delivering it to TUN. + // + // When a reply comes back from TUN addressed to the SNAT address, the + // router un-SNATs it (restoring the original destination) and tunnels it + // back to the sender. + + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + + unsafePrefix := "192.168.0.0/16" + snatAddr := netip.MustParseAddr("169.254.42.42") + + // Router: IPv6-only with unsafe networks and a manual SNAT address. + // Override inbound firewall with local_cidr: "any" so both IPv4 (unsafe) + // and IPv6 (VPN) traffic is accepted. + routerControl, routerVpnIpNet, routerUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks( + cert.Version2, ca, caKey, "router", "ff::1/64", + netip.MustParseAddrPort("[beef::1]:4242"), + unsafePrefix, + m{ + "firewall": m{ + "inbound": []m{{ + "proto": "any", + "port": "any", + "host": "any", + "local_cidr": "any", + }}, + }, + "tun": m{ + "snat_address_for_4over6": snatAddr.String(), + }, + }, + ) + + // Sender: IPv6-only with an unsafe route via the router + senderControl, _, _, _ := newSimpleServerWithUdp( + cert.Version2, ca, caKey, "sender", "ff::2/64", + netip.MustParseAddrPort("[beef::2]:4242"), + m{ + "tun": m{ + "unsafe_routes": []m{ + {"route": unsafePrefix, "via": routerVpnIpNet[0].Addr().String()}, + }, + }, + }, + ) + + // Tell sender where the router lives + senderControl.InjectLightHouseAddr(routerVpnIpNet[0].Addr(), routerUdpAddr) + + // Build the router and start both nodes + r := router.NewR(t, routerControl, senderControl) + defer r.RenderFlow() + + routerControl.Start() + senderControl.Start() + + // --- Outbound: sender -> IPv4 unsafe dest (via router with SNAT) --- + + origSrcIP := netip.MustParseAddr("10.0.0.1") + unsafeDest := netip.MustParseAddr("192.168.1.1") + var origSrcPort uint16 = 12345 + var dstPort uint16 = 80 + + t.Log("Sender injects an IPv4 packet to the unsafe network") + senderControl.InjectTunUDPPacket(unsafeDest, dstPort, origSrcIP, origSrcPort, []byte("snat me")) + + t.Log("Route packets (handshake + data) until the router gets the packet on TUN") + snatPkt := r.RouteForAllUntilTxTun(routerControl) + + t.Log("Verify the packet was SNATted") + gotSrcIP, gotDstIP, gotSrcPort, gotDstPort, gotPayload := parseIPv4UDPPacket(t, snatPkt) + assert.Equal(t, snatAddr, gotSrcIP, "source IP should be rewritten to the SNAT address") + assert.Equal(t, unsafeDest, gotDstIP, "destination IP should be unchanged") + assert.Equal(t, dstPort, gotDstPort, "destination port should be unchanged") + assert.Equal(t, []byte("snat me"), gotPayload, "payload should be unchanged") + + // Capture the SNAT port (may differ from original if port was remapped) + snatPort := gotSrcPort + t.Logf("SNAT port: %d (original: %d)", snatPort, origSrcPort) + + // --- Return: reply from unsafe dest -> un-SNATted back to sender --- + + t.Log("Router injects a reply packet from the unsafe dest to the SNAT address") + routerControl.InjectTunUDPPacket(snatAddr, snatPort, unsafeDest, dstPort, []byte("reply from unsafe")) + + t.Log("Route until sender gets the reply on TUN") + replyPkt := r.RouteForAllUntilTxTun(senderControl) + + t.Log("Verify the reply was un-SNATted") + replySrcIP, replyDstIP, replySrcPort, replyDstPort, replyPayload := parseIPv4UDPPacket(t, replyPkt) + assert.Equal(t, unsafeDest, replySrcIP, "reply source should be the unsafe dest") + assert.Equal(t, origSrcIP, replyDstIP, "reply dest should be the original source IP (un-SNATted)") + assert.Equal(t, dstPort, replySrcPort, "reply source port should be the unsafe dest port") + assert.Equal(t, origSrcPort, replyDstPort, "reply dest port should be the original source port (un-SNATted)") + assert.Equal(t, []byte("reply from unsafe"), replyPayload, "payload should be unchanged") + + r.RenderHostmaps("Final hostmaps", routerControl, senderControl) + + // Also verify normal IPv6 VPN traffic still works between the nodes + t.Log("Verify normal IPv6 VPN tunnel still works") + assertTunnel(t, routerVpnIpNet[0].Addr(), senderControl.GetVpnAddrs()[0], routerControl, senderControl, r) + + routerControl.Stop() + senderControl.Stop() +} + +func TestSNAT_MultipleFlows(t *testing.T) { + // Test that multiple distinct IPv4 flows from the same IPv6-only peer + // are tracked independently through SNAT. + + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + + unsafePrefix := "192.168.0.0/16" + snatAddr := netip.MustParseAddr("169.254.42.42") + + routerControl, routerVpnIpNet, routerUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks( + cert.Version2, ca, caKey, "router", "ff::1/64", + netip.MustParseAddrPort("[beef::1]:4242"), + unsafePrefix, + m{ + "firewall": m{ + "inbound": []m{{ + "proto": "any", + "port": "any", + "host": "any", + "local_cidr": "any", + }}, + }, + "tun": m{ + "snat_address_for_4over6": snatAddr.String(), + }, + }, + ) + + senderControl, _, _, _ := newSimpleServerWithUdp( + cert.Version2, ca, caKey, "sender", "ff::2/64", + netip.MustParseAddrPort("[beef::2]:4242"), + m{ + "tun": m{ + "unsafe_routes": []m{ + {"route": unsafePrefix, "via": routerVpnIpNet[0].Addr().String()}, + }, + }, + }, + ) + + senderControl.InjectLightHouseAddr(routerVpnIpNet[0].Addr(), routerUdpAddr) + + r := router.NewR(t, routerControl, senderControl) + defer r.RenderFlow() + r.CancelFlowLogs() + + routerControl.Start() + senderControl.Start() + + unsafeDest := netip.MustParseAddr("192.168.1.1") + + // Send first flow + senderControl.InjectTunUDPPacket(unsafeDest, 80, netip.MustParseAddr("10.0.0.1"), 1111, []byte("flow1")) + pkt1 := r.RouteForAllUntilTxTun(routerControl) + srcIP1, _, srcPort1, _, payload1 := parseIPv4UDPPacket(t, pkt1) + assert.Equal(t, snatAddr, srcIP1) + assert.Equal(t, []byte("flow1"), payload1) + + // Send second flow (different source port) + senderControl.InjectTunUDPPacket(unsafeDest, 80, netip.MustParseAddr("10.0.0.1"), 2222, []byte("flow2")) + pkt2 := r.RouteForAllUntilTxTun(routerControl) + srcIP2, _, srcPort2, _, payload2 := parseIPv4UDPPacket(t, pkt2) + assert.Equal(t, snatAddr, srcIP2) + assert.Equal(t, []byte("flow2"), payload2) + + // The two flows should have different SNAT ports (since they're different conntracks) + t.Logf("Flow 1 SNAT port: %d, Flow 2 SNAT port: %d", srcPort1, srcPort2) + + // Reply to flow 2 first (out of order) + routerControl.InjectTunUDPPacket(snatAddr, srcPort2, unsafeDest, 80, []byte("reply2")) + reply2 := r.RouteForAllUntilTxTun(senderControl) + _, replyDst2, _, replyDstPort2, replyPayload2 := parseIPv4UDPPacket(t, reply2) + assert.Equal(t, netip.MustParseAddr("10.0.0.1"), replyDst2) + assert.Equal(t, uint16(2222), replyDstPort2, "reply to flow 2 should restore original port 2222") + assert.Equal(t, []byte("reply2"), replyPayload2) + + // Reply to flow 1 + routerControl.InjectTunUDPPacket(snatAddr, srcPort1, unsafeDest, 80, []byte("reply1")) + reply1 := r.RouteForAllUntilTxTun(senderControl) + _, replyDst1, _, replyDstPort1, replyPayload1 := parseIPv4UDPPacket(t, reply1) + assert.Equal(t, netip.MustParseAddr("10.0.0.1"), replyDst1) + assert.Equal(t, uint16(1111), replyDstPort1, "reply to flow 1 should restore original port 1111") + assert.Equal(t, []byte("reply1"), replyPayload1) + + routerControl.Stop() + senderControl.Stop() +} + +// --- Adversarial SNAT E2E Tests --- + +func TestSNAT_UnsolicitedReplyDropped(t *testing.T) { + // Without any outbound SNAT traffic, inject a packet from the router's TUN + // addressed to the SNAT address. The sender must never receive it because + // there's no conntrack entry to un-SNAT through. + + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + + unsafePrefix := "192.168.0.0/16" + snatAddr := netip.MustParseAddr("169.254.42.42") + + routerControl, routerVpnIpNet, routerUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks( + cert.Version2, ca, caKey, "router", "ff::1/64", + netip.MustParseAddrPort("[beef::1]:4242"), + unsafePrefix, + m{ + "firewall": m{ + "inbound": []m{{ + "proto": "any", + "port": "any", + "host": "any", + "local_cidr": "any", + }}, + }, + "tun": m{ + "snat_address_for_4over6": snatAddr.String(), + }, + }, + ) + + senderControl, _, _, _ := newSimpleServerWithUdp( + cert.Version2, ca, caKey, "sender", "ff::2/64", + netip.MustParseAddrPort("[beef::2]:4242"), + m{ + "tun": m{ + "unsafe_routes": []m{ + {"route": unsafePrefix, "via": routerVpnIpNet[0].Addr().String()}, + }, + }, + }, + ) + + senderControl.InjectLightHouseAddr(routerVpnIpNet[0].Addr(), routerUdpAddr) + + r := router.NewR(t, routerControl, senderControl) + defer r.RenderFlow() + r.CancelFlowLogs() + + routerControl.Start() + senderControl.Start() + + // First establish the tunnel with normal IPv6 traffic so handshake completes + assertTunnel(t, routerVpnIpNet[0].Addr(), senderControl.GetVpnAddrs()[0], routerControl, senderControl, r) + + // Inject the unsolicited reply from router's TUN to the SNAT address. + // There is NO prior outbound SNAT flow, so no conntrack entry exists. + // The router should silently drop this because unSnat finds no matching conntrack. + routerControl.InjectTunUDPPacket(snatAddr, 55555, netip.MustParseAddr("192.168.1.1"), 80, []byte("unsolicited")) + + // Send a canary IPv6 VPN packet after the bad one. Since the router processes + // TUN packets sequentially, the canary arriving proves the bad packet was processed first. + senderVpnAddr := senderControl.GetVpnAddrs()[0] + routerControl.InjectTunUDPPacket(senderVpnAddr, 90, routerVpnIpNet[0].Addr(), 80, []byte("canary")) + canaryPkt := r.RouteForAllUntilTxTun(senderControl) + assertUdpPacket(t, []byte("canary"), canaryPkt, routerVpnIpNet[0].Addr(), senderVpnAddr, 80, 90) + + // The unsolicited packet should have been dropped — nothing else on sender's TUN + got := senderControl.GetFromTun(false) + assert.Nil(t, got, "sender should not receive unsolicited packet to SNAT address with no conntrack entry") + + routerControl.Stop() + senderControl.Stop() +} + +func TestSNAT_NonUnsafeDestDropped(t *testing.T) { + // An IPv6-only sender sends IPv4 traffic to a destination outside the router's + // unsafe networks (172.16.0.1 when unsafe is 192.168.0.0/16). The router should + // reject this because the local address is not routable. This verifies that + // willingToHandleLocalAddr enforces boundaries on what SNAT traffic can reach. + + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + + unsafePrefix := "192.168.0.0/16" + snatAddr := netip.MustParseAddr("169.254.42.42") + + routerControl, routerVpnIpNet, routerUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks( + cert.Version2, ca, caKey, "router", "ff::1/64", + netip.MustParseAddrPort("[beef::1]:4242"), + unsafePrefix, + m{ + "firewall": m{ + "inbound": []m{{ + "proto": "any", + "port": "any", + "host": "any", + "local_cidr": "any", + }}, + }, + "tun": m{ + "snat_address_for_4over6": snatAddr.String(), + }, + }, + ) + + // Sender has unsafe routes for BOTH 192.168.0.0/16 AND 172.16.0.0/12 via router. + // This means the sender will route 172.16.0.1 through the tunnel to the router. + // But the router should reject it because 172.16.0.0/12 is NOT in its unsafe networks. + senderControl, _, _, _ := newSimpleServerWithUdp( + cert.Version2, ca, caKey, "sender", "ff::2/64", + netip.MustParseAddrPort("[beef::2]:4242"), + m{ + "tun": m{ + "unsafe_routes": []m{ + {"route": unsafePrefix, "via": routerVpnIpNet[0].Addr().String()}, + {"route": "172.16.0.0/12", "via": routerVpnIpNet[0].Addr().String()}, + }, + }, + }, + ) + + senderControl.InjectLightHouseAddr(routerVpnIpNet[0].Addr(), routerUdpAddr) + + r := router.NewR(t, routerControl, senderControl) + defer r.RenderFlow() + r.CancelFlowLogs() + + routerControl.Start() + senderControl.Start() + + // Establish the tunnel first + assertTunnel(t, routerVpnIpNet[0].Addr(), senderControl.GetVpnAddrs()[0], routerControl, senderControl, r) + + // Send to 172.16.0.1 (NOT in router's unsafe networks 192.168.0.0/16). + // The router should reject this at willingToHandleLocalAddr. + senderControl.InjectTunUDPPacket( + netip.MustParseAddr("172.16.0.1"), 80, + netip.MustParseAddr("10.0.0.1"), 12345, + []byte("wrong dest"), + ) + + // Send a canary to a valid unsafe destination to prove the bad packet was processed + senderControl.InjectTunUDPPacket( + netip.MustParseAddr("192.168.1.1"), 80, + netip.MustParseAddr("10.0.0.1"), 33333, + []byte("canary"), + ) + + // Route until the canary arrives — the 172.16.0.1 packet should have been + // processed and dropped before the canary gets through + canaryPkt := r.RouteForAllUntilTxTun(routerControl) + _, canaryDst, _, _, canaryPayload := parseIPv4UDPPacket(t, canaryPkt) + assert.Equal(t, netip.MustParseAddr("192.168.1.1"), canaryDst, "canary should arrive at the valid unsafe dest") + assert.Equal(t, []byte("canary"), canaryPayload) + + // No more packets — the 172.16.0.1 packet was dropped + got := routerControl.GetFromTun(false) + assert.Nil(t, got, "packet to non-unsafe destination 172.16.0.1 should be dropped by the router") + + routerControl.Stop() + senderControl.Stop() +} diff --git a/firewall.go b/firewall.go index 93b16891..e1f63b64 100644 --- a/firewall.go +++ b/firewall.go @@ -2,6 +2,7 @@ package nebula import ( "crypto/sha256" + "encoding/binary" "encoding/hex" "errors" "fmt" @@ -22,10 +23,35 @@ import ( "github.com/slackhq/nebula/firewall" ) +var ErrCannotSNAT = errors.New("cannot SNAT this packet") +var ErrSNATIdentityMismatch = errors.New("refusing to SNAT for mismatched host") +var ErrSNATAddressCollision = errors.New("refusing to accept an incoming packet with my SNAT address") + +const ipv4SourcePosition = 12 +const ipv4DestinationPosition = 16 +const sourcePortOffset = 0 +const destinationPortOffset = 2 + type FirewallInterface interface { AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr string, caName string, caSha string) error } +type snatInfo struct { + //Src is the source IP+port to write into unsafe-route-bound packet + Src netip.AddrPort + //SrcVpnIp is the overlay IP associated with this flow. It's needed to associate reply traffic so we can get it back to the right host. + SrcVpnIp netip.Addr + //SnatPort is the port to rewrite into an overlay-bound packet + SnatPort uint16 +} + +func (s *snatInfo) Valid() bool { + if s == nil { + return false + } + return s.Src.IsValid() +} + type conn struct { Expires time.Time // Time when this conntrack entry will expire @@ -34,6 +60,9 @@ type conn struct { // fields pack for free after the uint32 above incoming bool rulesVersion uint16 + + //for SNAT support + snat *snatInfo } // TODO: need conntrack max tracked connections handling @@ -66,6 +95,8 @@ type Firewall struct { defaultLocalCIDRAny bool incomingMetrics firewallMetrics outgoingMetrics firewallMetrics + unsafeIPv4Origin netip.Addr + snatAddr netip.Addr l *logrus.Logger } @@ -193,12 +224,12 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) { certificate := cs.getCertificate(cert.Version2) - if certificate == nil { + if certificate == nil { //todo if config.initiating_version is set to 1, and unsafe_networks differ, things will suck certificate = cs.getCertificate(cert.Version1) } if certificate == nil { - panic("No certificate available to reconfigure the firewall") + return nil, errors.New("no certificate available to reconfigure the firewall") } fw := NewFirewall( @@ -207,7 +238,6 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3), c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10), certificate, - //TODO: max_connections ) fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false) @@ -314,6 +344,18 @@ func (f *Firewall) GetRuleHashes() string { return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10) } +func (f *Firewall) SetSNATAddressFromInterface(i *Interface) { + //address-mutation-avoidance is done inside Interface, the firewall doesn't need to care + //todo should snatted conntracks get expired out? Probably not needed until if/when we allow reload + f.snatAddr = i.inside.SNATAddress().Addr() + f.unsafeIPv4Origin = i.inside.UnsafeIPv4OriginAddress().Addr() +} + +func (f *Firewall) ShouldUnSNAT(fp *firewall.Packet) bool { + // f.snatAddr is only valid if we're a snat-capable router + return f.snatAddr.IsValid() && fp.RemoteAddr == f.snatAddr +} + func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error { var table string if inbound { @@ -414,50 +456,207 @@ var ErrInvalidRemoteIP = errors.New("remote address is not in remote certificate var ErrInvalidLocalIP = errors.New("local address is not in list of handled local addresses") var ErrNoMatchingRule = errors.New("no matching rule in firewall table") +func (f *Firewall) unSnat(data []byte, fp *firewall.Packet) netip.Addr { + c := f.peek(*fp) //unfortunately this needs to lock. Surely there's a better way. + if c == nil { + return netip.Addr{} + } + if !c.snat.Valid() { + return netip.Addr{} + } + oldIP := netip.AddrPortFrom(f.snatAddr, fp.RemotePort) + rewritePacket(data, fp, oldIP, c.snat.Src, ipv4DestinationPosition, destinationPortOffset) + return c.snat.SrcVpnIp +} + +func rewritePacket(data []byte, fp *firewall.Packet, oldIP netip.AddrPort, newIP netip.AddrPort, ipOffset int, portOffset int) { + //change address + copy(data[ipOffset:], newIP.Addr().AsSlice()) + recalcIPv4Checksum(data, oldIP.Addr(), newIP.Addr()) + ipHeaderLen := int(data[0]&0x0F) * 4 + + switch fp.Protocol { + case firewall.ProtoICMP: + binary.BigEndian.PutUint16(data[ipHeaderLen+4:ipHeaderLen+6], newIP.Port()) //we use the ID field as a "port" for ICMP + icmpCode := uint16(data[ipHeaderLen+1]) //todo not snatting on code yet (but Linux would) + recalcICMPv4Checksum(data, icmpCode, icmpCode, oldIP.Port(), newIP.Port()) + case firewall.ProtoUDP: + dstport := ipHeaderLen + portOffset + binary.BigEndian.PutUint16(data[dstport:dstport+2], newIP.Port()) + recalcUDPv4Checksum(data, oldIP, newIP) + case firewall.ProtoTCP: + dstport := ipHeaderLen + portOffset + binary.BigEndian.PutUint16(data[dstport:dstport+2], newIP.Port()) + recalcTCPv4Checksum(data, oldIP, newIP) + } +} + +func (f *Firewall) findUsableSNATPort(fp *firewall.Packet, c *conn) error { + const halfThePorts = 0x7fff + oldPort := fp.RemotePort + conntrack := f.Conntrack + conntrack.Lock() + defer conntrack.Unlock() + for numPortsChecked := 0; numPortsChecked < halfThePorts; numPortsChecked++ { + _, ok := conntrack.Conns[*fp] + if !ok { + //yay, we can use this port + //track the snatted flow with the same expiration as the unsnatted version + c.snat.SnatPort = fp.RemotePort + conntrack.Conns[*fp] = c + return nil + } + //increment and retry. There's probably better strategies out there + fp.RemotePort++ + if fp.RemotePort < halfThePorts { + fp.RemotePort += halfThePorts // keep it ephemeral for now + } + } + + //if we made it here, we failed + fp.RemotePort = oldPort + return ErrCannotSNAT +} + +func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo *HostInfo) error { + if !f.snatAddr.IsValid() { + return ErrCannotSNAT + } + if f.snatAddr == fp.LocalAddr { //a packet that came from UDP (incoming) should never ever have our snat address on it + return ErrSNATAddressCollision + } + if c.snat.Valid() { + //old flow: make sure it came from the right place + if !slices.Contains(hostinfo.vpnAddrs, c.snat.SrcVpnIp) { + return ErrSNATIdentityMismatch + } + fp.RemoteAddr = f.snatAddr + fp.RemotePort = c.snat.SnatPort + } else if hostinfo.vpnAddrs[0].Is6() { + //we got a new flow + c.snat = &snatInfo{ + Src: netip.AddrPortFrom(fp.RemoteAddr, fp.RemotePort), + SrcVpnIp: hostinfo.vpnAddrs[0], + } + fp.RemoteAddr = f.snatAddr + //find a new port to use, if needed + err := f.findUsableSNATPort(fp, c) + if err != nil { + c.snat = nil + return err + } + } else { + return ErrCannotSNAT + } + + newIP := netip.AddrPortFrom(f.snatAddr, c.snat.SnatPort) + rewritePacket(data, fp, c.snat.Src, newIP, ipv4SourcePosition, sourcePortOffset) + return nil +} + +func (f *Firewall) identifyRemoteNetworkType(h *HostInfo, fp firewall.Packet) NetworkType { + if h.networks == nil { + // Simple case: Certificate has one address and no unsafe networks + if h.vpnAddrs[0] == fp.RemoteAddr { + return NetworkTypeVPN + } //else, fallthrough + } else if nwType, ok := h.networks.Lookup(fp.RemoteAddr); ok { + return nwType //will return NetworkTypeVPN or NetworkTypeUnsafe + } + + //RemoteAddr not in our networks table + if f.snatAddr.IsValid() && fp.IsIPv4() && h.HasOnlyV6Addresses() { + return NetworkTypeUnverifiedSNATPeer + } else { + return NetworkTypeInvalidPeer + } +} + +func (f *Firewall) allowRemoteNetworkType(nwType NetworkType, fp firewall.Packet) error { + switch nwType { + case NetworkTypeVPN: + return nil + case NetworkTypeInvalidPeer: + return ErrInvalidRemoteIP + case NetworkTypeVPNPeer: + //one day we might need a specialSnatMode case in here to handle routers with v4 addresses when we don't also have a v4 address? + return ErrPeerRejected // reject for now, one day this may have different FW rules + case NetworkTypeUnsafe: + return nil // nothing special, one day this may have different FW rules + case NetworkTypeUnverifiedSNATPeer: + if f.unsafeIPv4Origin.IsValid() && fp.LocalAddr == f.unsafeIPv4Origin { + return nil //the client case + } + if f.snatAddr.IsValid() { + if fp.RemoteAddr == f.snatAddr { + return ErrInvalidRemoteIP //we should never get a packet with our SNAT addr as the destination, or "from" our SNAT addr + } + return nil + } else { + return ErrInvalidRemoteIP + } + default: + return ErrUnknownNetworkType //should never happen + } +} + +func (f *Firewall) willingToHandleLocalAddr(incoming bool, fp firewall.Packet, remoteNwType NetworkType) error { + if f.routableNetworks.Contains(fp.LocalAddr) { + return nil //easy, this should handle NetworkTypeVPN in all cases, and NetworkTypeUnsafe on the router side + } + if incoming { //at least for now, reject all traffic other than what we've already decided is locally routable + return ErrInvalidLocalIP + } + + //below this line, all traffic is outgoing. Outgoing traffic to NetworkTypeUnsafe is not required to be considered inbound-routable + if remoteNwType == NetworkTypeUnsafe { + return nil + } + + return ErrInvalidLocalIP +} + // Drop returns an error if the packet should be dropped, explaining why. It // returns nil if the packet should not be dropped. -func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error { +func (f *Firewall) Drop(fp firewall.Packet, pkt []byte, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error { + table := f.OutRules + if incoming { + table = f.InRules + } + + snatmode := fp.IsIPv4() && h.HasOnlyV6Addresses() && f.snatAddr.IsValid() + if snatmode { + //if this is an IPv4 packet from a V6 only host, and we're configured to snat that kind of traffic, it must be snatted, + //so it can never be in the localcache, which lacks SNAT data + //nil out the pointer to avoid ever using it + localCache = nil + } + // Check if we spoke to this tuple, if we did then allow this packet - if f.inConns(fp, h, caPool, localCache) { + if localCache != nil { + if _, ok := localCache[fp]; ok { + return nil //packet matched the cache, we're not snatting, we can return early + } + } + c := f.inConns(fp, h, caPool, localCache) + if c != nil { + if incoming && snatmode { + return f.applySnat(pkt, &fp, c, h) + } return nil } // Make sure remote address matches nebula certificate, and determine how to treat it - if h.networks == nil { - // Simple case: Certificate has one address and no unsafe networks - if h.vpnAddrs[0] != fp.RemoteAddr { - f.metrics(incoming).droppedRemoteAddr.Inc(1) - return ErrInvalidRemoteIP - } - } else { - nwType, ok := h.networks.Lookup(fp.RemoteAddr) - if !ok { - f.metrics(incoming).droppedRemoteAddr.Inc(1) - return ErrInvalidRemoteIP - } - switch nwType { - case NetworkTypeVPN: - break // nothing special - case NetworkTypeVPNPeer: - f.metrics(incoming).droppedRemoteAddr.Inc(1) - return ErrPeerRejected // reject for now, one day this may have different FW rules - case NetworkTypeUnsafe: - break // nothing special, one day this may have different FW rules - default: - f.metrics(incoming).droppedRemoteAddr.Inc(1) - return ErrUnknownNetworkType //should never happen - } + remoteNetworkType := f.identifyRemoteNetworkType(h, fp) + if err := f.allowRemoteNetworkType(remoteNetworkType, fp); err != nil { + f.metrics(incoming).droppedRemoteAddr.Inc(1) + return err } // Make sure we are supposed to be handling this local ip address - if !f.routableNetworks.Contains(fp.LocalAddr) { + if err := f.willingToHandleLocalAddr(incoming, fp, remoteNetworkType); err != nil { f.metrics(incoming).droppedLocalAddr.Inc(1) - return ErrInvalidLocalIP - } - - table := f.OutRules - if incoming { - table = f.InRules + return err } // We now know which firewall table to check against @@ -467,9 +666,14 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * } // We always want to conntrack since it is a faster operation - f.addConn(fp, incoming) + c = f.addConn(fp, incoming) - return nil + if incoming && remoteNetworkType == NetworkTypeUnverifiedSNATPeer { + return f.applySnat(pkt, &fp, c, h) + } else { + //outgoing snat is handled before this function is called + return nil + } } func (f *Firewall) metrics(incoming bool) firewallMetrics { @@ -496,12 +700,14 @@ func (f *Firewall) EmitStats() { metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV())) } -func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool { - if localCache != nil { - if _, ok := localCache[fp]; ok { - return true - } - } +func (f *Firewall) peek(fp firewall.Packet) *conn { + f.Conntrack.Lock() + c := f.Conntrack.Conns[fp] + f.Conntrack.Unlock() + return c +} + +func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) *conn { conntrack := f.Conntrack conntrack.Lock() @@ -515,7 +721,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, if !ok { conntrack.Unlock() - return false + return nil } if c.rulesVersion != f.rulesVersion { @@ -538,7 +744,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, } delete(conntrack.Conns, fp) conntrack.Unlock() - return false + return nil } if f.l.Level >= logrus.DebugLevel { @@ -568,12 +774,11 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache[fp] = struct{}{} } - return true + return c } -func (f *Firewall) addConn(fp firewall.Packet, incoming bool) { +func (f *Firewall) packetTimeout(fp firewall.Packet) time.Duration { var timeout time.Duration - c := &conn{} switch fp.Protocol { case firewall.ProtoTCP: @@ -583,7 +788,13 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) { default: timeout = f.DefaultTimeout } + return timeout +} +func (f *Firewall) addConn(fp firewall.Packet, incoming bool) *conn { + c := &conn{} + + timeout := f.packetTimeout(fp) conntrack := f.Conntrack conntrack.Lock() if _, ok := conntrack.Conns[fp]; !ok { @@ -597,7 +808,9 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) { c.rulesVersion = f.rulesVersion c.Expires = time.Now().Add(timeout) conntrack.Conns[fp] = c + conntrack.Unlock() + return c } // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel diff --git a/firewall/packet.go b/firewall/packet.go index 2cbfb5ea..fac3baab 100644 --- a/firewall/packet.go +++ b/firewall/packet.go @@ -31,6 +31,10 @@ type Packet struct { Fragment bool } +func (fp *Packet) IsIPv4() bool { + return fp.LocalAddr.Is4() && fp.RemoteAddr.Is4() +} + func (fp *Packet) Copy() *Packet { return &Packet{ LocalAddr: fp.LocalAddr, diff --git a/firewall_test.go b/firewall_test.go index a2133760..dc863319 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -213,44 +213,44 @@ func TestFirewall_Drop(t *testing.T) { cp := cert.NewCAPool() // Drop outbound - assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil)) + assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, nil, false, &h, cp, nil)) // Allow inbound resetConntrack(fw) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) // Allow outbound because conntrack - require.NoError(t, fw.Drop(p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil)) // test remote mismatch oldRemote := p.RemoteAddr p.RemoteAddr = netip.MustParseAddr("1.2.3.10") - assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP) + assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrInvalidRemoteIP) p.RemoteAddr = oldRemote // ensure signer doesn't get in the way of group checks fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) - assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) - assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) } func TestFirewall_DropV6(t *testing.T) { @@ -292,44 +292,44 @@ func TestFirewall_DropV6(t *testing.T) { cp := cert.NewCAPool() // Drop outbound - assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil)) + assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, nil, false, &h, cp, nil)) // Allow inbound resetConntrack(fw) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) // Allow outbound because conntrack - require.NoError(t, fw.Drop(p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil)) // test remote mismatch oldRemote := p.RemoteAddr p.RemoteAddr = netip.MustParseAddr("fd12::56") - assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP) + assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrInvalidRemoteIP) p.RemoteAddr = oldRemote // ensure signer doesn't get in the way of group checks fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) - assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) - assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) } func BenchmarkFirewallTable_match(b *testing.B) { @@ -537,10 +537,10 @@ func TestFirewall_Drop2(t *testing.T) { cp := cert.NewCAPool() // h1/c1 lacks the proper groups - require.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule) + require.ErrorIs(t, fw.Drop(p, nil, true, &h1, cp, nil), ErrNoMatchingRule) // c has the proper groups resetConntrack(fw) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) } func TestFirewall_Drop3(t *testing.T) { @@ -618,18 +618,18 @@ func TestFirewall_Drop3(t *testing.T) { cp := cert.NewCAPool() // c1 should pass because host match - require.NoError(t, fw.Drop(p, true, &h1, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h1, cp, nil)) // c2 should pass because ca sha match resetConntrack(fw) - require.NoError(t, fw.Drop(p, true, &h2, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h2, cp, nil)) // c3 should fail because no match resetConntrack(fw) - assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p, nil, true, &h3, cp, nil), ErrNoMatchingRule) // Test a remote address match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", "")) - require.NoError(t, fw.Drop(p, true, &h1, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h1, cp, nil)) } func TestFirewall_Drop3V6(t *testing.T) { @@ -667,7 +667,7 @@ func TestFirewall_Drop3V6(t *testing.T) { fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) cp := cert.NewCAPool() require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", "")) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) } func TestFirewall_DropConntrackReload(t *testing.T) { @@ -709,12 +709,12 @@ func TestFirewall_DropConntrackReload(t *testing.T) { cp := cert.NewCAPool() // Drop outbound - assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) // Allow outbound because conntrack - require.NoError(t, fw.Drop(p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil)) oldFw := fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) @@ -723,7 +723,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { fw.rulesVersion = oldFw.rulesVersion + 1 // Allow outbound because conntrack and new rules allow port 10 - require.NoError(t, fw.Drop(p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil)) oldFw = fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) @@ -732,7 +732,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { fw.rulesVersion = oldFw.rulesVersion + 1 // Drop outbound because conntrack doesn't match new ruleset - assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrNoMatchingRule) } func TestFirewall_ICMPPortBehavior(t *testing.T) { @@ -778,12 +778,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) { p.LocalPort = 0 p.RemotePort = 0 // Drop outbound - assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - require.NoError(t, fw.Drop(*p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil)) //now also allow outbound - require.NoError(t, fw.Drop(*p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil)) }) t.Run("nonzero ports", func(t *testing.T) { @@ -791,12 +791,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) { p.LocalPort = 0xabcd p.RemotePort = 0x1234 // Drop outbound - assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - require.NoError(t, fw.Drop(*p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil)) //now also allow outbound - require.NoError(t, fw.Drop(*p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil)) }) }) @@ -808,12 +808,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) { p.LocalPort = 0 p.RemotePort = 0 // Drop outbound - assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule) //now also allow outbound - assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) }) t.Run("nonzero ports, still blocked", func(t *testing.T) { @@ -821,12 +821,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) { p.LocalPort = 0xabcd p.RemotePort = 0x1234 // Drop outbound - assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule) //now also allow outbound - assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) }) t.Run("nonzero, matching ports, still blocked", func(t *testing.T) { @@ -834,12 +834,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) { p.LocalPort = 80 p.RemotePort = 80 // Drop outbound - assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule) //now also allow outbound - assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) }) }) t.Run("Any proto, any port", func(t *testing.T) { @@ -851,12 +851,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) { p.LocalPort = 0 p.RemotePort = 0 // Drop outbound - assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - require.NoError(t, fw.Drop(*p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil)) //now also allow outbound - require.NoError(t, fw.Drop(*p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil)) }) t.Run("nonzero ports, allowed", func(t *testing.T) { @@ -865,15 +865,15 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) { p.LocalPort = 0xabcd p.RemotePort = 0x1234 // Drop outbound - assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - require.NoError(t, fw.Drop(*p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil)) //now also allow outbound - require.NoError(t, fw.Drop(*p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil)) //different ID is blocked p.RemotePort++ - require.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + require.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) }) }) @@ -922,7 +922,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) { Protocol: firewall.ProtoUDP, Fragment: false, } - assert.Equal(t, fw.Drop(p, true, &h1, cp, nil), ErrInvalidRemoteIP) + assert.Equal(t, fw.Drop(p, nil, true, &h1, cp, nil), ErrInvalidRemoteIP) } func BenchmarkLookup(b *testing.B) { @@ -1042,7 +1042,7 @@ func TestNewFirewallFromConfig(t *testing.T) { l := test.NewLogger() // Test a bad rule definition c := &dummyCert{} - cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil) + cs, err := newCertState(l, cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil) require.NoError(t, err) conf := config.NewC(l) @@ -1336,7 +1336,7 @@ func (c *testcase) Test(t *testing.T, fw *Firewall) { t.Helper() cp := cert.NewCAPool() resetConntrack(fw) - err := fw.Drop(c.p, true, c.h, cp, nil) + err := fw.Drop(c.p, nil, true, c.h, cp, nil) if c.err == nil { require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr) } else { @@ -1344,7 +1344,7 @@ func (c *testcase) Test(t *testing.T, fw *Firewall) { } } -func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) testcase { +func buildHostinfo(setup testsetup, theirPrefixes ...netip.Prefix) *HostInfo { c1 := dummyCert{ name: "host1", networks: theirPrefixes, @@ -1364,6 +1364,11 @@ func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) te h.vpnAddrs[i] = theirPrefixes[i].Addr() } h.buildNetworks(setup.myVpnNetworksTable, &c1) + return &h +} + +func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) testcase { + h := buildHostinfo(setup, theirPrefixes...) p := firewall.Packet{ LocalAddr: setup.c.Networks()[0].Addr(), //todo? RemoteAddr: theirPrefixes[0].Addr(), @@ -1373,9 +1378,9 @@ func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) te Fragment: false, } return testcase{ - h: &h, + h: h, p: p, - c: &c1, + c: h.ConnectionState.peerCert.Certificate, err: err, } } @@ -1397,6 +1402,19 @@ func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testse return newSetupFromCert(t, l, c) } +func newSnatSetup(t *testing.T, l *logrus.Logger, myPrefix netip.Prefix, snatAddr netip.Addr) testsetup { + c := dummyCert{ + name: "me", + networks: []netip.Prefix{myPrefix}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + out := newSetupFromCert(t, l, c) + out.fw.snatAddr = snatAddr + return out +} + func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup { myVpnNetworksTable := new(bart.Lite) for _, prefix := range c.Networks() { @@ -1532,3 +1550,59 @@ func resetConntrack(fw *Firewall) { fw.Conntrack.Conns = map[firewall.Packet]*conn{} fw.Conntrack.Unlock() } + +func TestFirewall_SNAT(t *testing.T) { + t.Parallel() + l := test.NewLogger() + ob := &bytes.Buffer{} + l.SetOutput(ob) + cp := cert.NewCAPool() + myPrefix := netip.MustParsePrefix("1.1.1.1/8") + + MyCert := dummyCert{ + name: "me", + networks: []netip.Prefix{myPrefix}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + theirPrefix := netip.MustParsePrefix("1.2.2.2/8") + snatAddr := netip.MustParseAddr("169.254.55.96") + t.Run("allow inbound all matching", func(t *testing.T) { + t.Parallel() + myCert := MyCert.Copy() + setup := newSnatSetup(t, l, myPrefix, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) + resetConntrack(setup.fw) + h := buildHostinfo(setup, theirPrefix) + p := firewall.Packet{ + LocalAddr: setup.c.Networks()[0].Addr(), //todo? + RemoteAddr: h.vpnAddrs[0], + LocalPort: 10, + RemotePort: 90, + Protocol: firewall.ProtoUDP, + Fragment: false, + } + require.NoError(t, setup.fw.Drop(p, nil, true, h, cp, nil)) + }) + //t.Run("allow inbound unsafe route", func(t *testing.T) { + // t.Parallel() + // unsafePrefix := netip.MustParsePrefix("192.168.0.0/24") + // c := dummyCert{ + // name: "me", + // networks: []netip.Prefix{myPrefix}, + // unsafeNetworks: []netip.Prefix{unsafePrefix}, + // groups: []string{"default-group"}, + // issuer: "signer-shasum", + // } + // unsafeSetup := newSetupFromCert(t, l, c) + // tc := buildTestCase(unsafeSetup, nil, twoPrefixes...) + // tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3") + // tc.err = ErrNoMatchingRule + // tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off + // require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", unsafePrefix.String(), "", "")) + // tc.err = nil + // tc.Test(t, unsafeSetup.fw) //should pass + //}) +} diff --git a/hostmap.go b/hostmap.go index 7e2939e0..f50dd875 100644 --- a/hostmap.go +++ b/hostmap.go @@ -224,6 +224,9 @@ const ( NetworkTypeVPNPeer // NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks() NetworkTypeUnsafe + // NetworkTypeUnverifiedSNATPeer is used to indicate traffic we're willing to route, but never deliver to a NetworkTypeVPN + NetworkTypeUnverifiedSNATPeer + NetworkTypeInvalidPeer ) type HostInfo struct { @@ -277,6 +280,15 @@ type HostInfo struct { lastUsed time.Time } +func (i *HostInfo) HasOnlyV6Addresses() bool { + for _, vpnIp := range i.vpnAddrs { + if !vpnIp.Is6() { + return false + } + } + return true +} + type ViaSender struct { UdpAddr netip.AddrPort relayHI *HostInfo // relayHI is the host info object of the relay diff --git a/inside.go b/inside.go index 0d53f952..a4413aa0 100644 --- a/inside.go +++ b/inside.go @@ -48,9 +48,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet return } - hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) { - hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) - }) + hostinfo, ready := f.getHostinfo(packet, fwPacket) if hostinfo == nil { f.rejectInside(packet, out, q) @@ -66,10 +64,9 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet return } - dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) + dropReason := f.firewall.Drop(*fwPacket, packet, false, hostinfo, f.pki.GetCAPool(), localCache) if dropReason == nil { f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q) - } else { f.rejectInside(packet, out, q) if f.l.Level >= logrus.DebugLevel { @@ -81,6 +78,26 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } } +func (f *Interface) getHostinfo(packet []byte, fwPacket *firewall.Packet) (*HostInfo, bool) { + if f.firewall.ShouldUnSNAT(fwPacket) { + //unsnat packet re-writing also happens here, would be nice to not, + //but we need to do the unsnat lookup to find the hostinfo so we can run the firewall checks + destVpnAddr := f.firewall.unSnat(packet, fwPacket) + if destVpnAddr.IsValid() { + //because this was a snatted packet, we know it has an on-overlay destination, so no routing should be required. + return f.getOrHandshakeNoRouting(destVpnAddr, func(hh *HandshakeHostInfo) { + hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) + }) + } else { + return nil, false + } + } else { //if we didn't need to unsnat + return f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) { + hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) + }) + } +} + func (f *Interface) rejectInside(packet []byte, out []byte, q int) { if !f.firewall.InSendReject { return @@ -218,7 +235,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp } // check if packet is in outbound fw rules - dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil) + dropReason := f.firewall.Drop(*fp, p, false, hostinfo, f.pki.GetCAPool(), nil) if dropReason != nil { if f.l.Level >= logrus.DebugLevel { f.l.WithField("fwPacket", fp). diff --git a/interface.go b/interface.go index 61b1f228..0acbc147 100644 --- a/interface.go +++ b/interface.go @@ -248,6 +248,7 @@ func (f *Interface) activate() { f.inside.Close() f.l.Fatal(err) } + f.firewall.SetSNATAddressFromInterface(f) } func (f *Interface) run() { @@ -344,6 +345,7 @@ func (f *Interface) reloadFirewall(c *config.C) { f.l.WithError(err).Error("Error while creating firewall during reload") return } + fw.SetSNATAddressFromInterface(f) oldFw := f.firewall conntrack := oldFw.Conntrack diff --git a/main.go b/main.go index 74979417..975bdebf 100644 --- a/main.go +++ b/main.go @@ -131,7 +131,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg deviceFactory = overlay.NewDeviceFromConfig } - tun, err = deviceFactory(c, l, pki.getCertState().myVpnNetworks, routines) + cs := pki.getCertState() + tun, err = deviceFactory(c, l, cs.myVpnNetworks, cs.GetDefaultCertificate().UnsafeNetworks(), routines) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err) } diff --git a/outside.go b/outside.go index b2cbf123..e78b1cb7 100644 --- a/outside.go +++ b/outside.go @@ -514,7 +514,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out return false } - dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache) + dropReason := f.firewall.Drop(*fwPacket, out, true, hostinfo, f.pki.GetCAPool(), localCache) if dropReason != nil { // NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore // This gives us a buffer to build the reject packet in diff --git a/overlay/device.go b/overlay/device.go index b6077aba..0f2f44c2 100644 --- a/overlay/device.go +++ b/overlay/device.go @@ -11,6 +11,9 @@ type Device interface { io.ReadWriteCloser Activate() error Networks() []netip.Prefix + UnsafeNetworks() []netip.Prefix + UnsafeIPv4OriginAddress() netip.Prefix + SNATAddress() netip.Prefix Name() string RoutesFor(netip.Addr) routing.Gateways SupportsMultiqueue() bool diff --git a/overlay/tun.go b/overlay/tun.go index e0bf69f6..d18ce123 100644 --- a/overlay/tun.go +++ b/overlay/tun.go @@ -1,7 +1,9 @@ package overlay import ( + "crypto/rand" "fmt" + "io" "net" "net/netip" @@ -22,22 +24,22 @@ func (e *NameError) Error() string { } // TODO: We may be able to remove routines -type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) +type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error) -func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { +func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error) { switch { case c.GetBool("tun.disabled", false): tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) return tun, nil default: - return newTun(c, l, vpnNetworks, routines > 1) + return newTun(c, l, vpnNetworks, unsafeNetworks, routines > 1) } } func NewFdDeviceFromConfig(fd *int) DeviceFactory { - return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { - return newTunFromFd(c, l, *fd, vpnNetworks) + return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error) { + return newTunFromFd(c, l, *fd, vpnNetworks, unsafeNetworks) } } @@ -129,3 +131,85 @@ func selectGateway(dest netip.Prefix, gateways []netip.Prefix) (netip.Prefix, er return netip.Prefix{}, fmt.Errorf("no gateway found for %v in the list of vpn networks", dest) } + +// genLinkLocal generates a random IPv4 link-local address. +// If randomizer is nil, it uses rand.Reader to find two random bytes +func genLinkLocal(randomizer io.Reader) netip.Prefix { + if randomizer == nil { + randomizer = rand.Reader + } + octets := []byte{169, 254, 0, 0} + _, _ = randomizer.Read(octets[2:4]) + return coerceLinkLocal(octets) +} + +func coerceLinkLocal(octets []byte) netip.Prefix { + if octets[3] == 0 { + octets[3] = 1 //please no .0 addresses + } else if octets[2] == 255 && octets[3] == 255 { + octets[3] = 254 //please no broadcast addresses + } + out, _ := netip.AddrFromSlice(octets) + return netip.PrefixFrom(out, 32) +} + +// prepareUnsafeOriginAddr provides the IPv4 address used on IPv6-only clients that need to access IPv4 unsafe routes +func prepareUnsafeOriginAddr(d Device, l *logrus.Logger, c *config.C, routes []Route) netip.Prefix { + if !d.Networks()[0].Addr().Is6() { + return netip.Prefix{} //if we have an IPv4 assignment within the overlay, we don't need an unsafe origin address + } + + needed := false + for _, route := range routes { //or if we have a route defined into an IPv4 range + if route.Cidr.Addr().Is4() { + needed = true //todo should this only apply to unsafe routes? almost certainly + break + } + } + if !needed { + return netip.Prefix{} + } + + //todo better config name for sure + if a := c.GetString("tun.unsafe_origin_address_for_4over6", ""); a != "" { + out, err := netip.ParseAddr(a) + if err != nil { + l.WithField("value", a).WithError(err).Warn("failed to parse tun.unsafe_origin_address_for_4over6, will use a random value") + } else if !out.Is4() || !out.IsLinkLocalUnicast() { + l.WithField("value", out).Warn("tun.unsafe_origin_address_for_4over6 must be an IPv4 address") + } else if out.IsValid() { + return netip.PrefixFrom(out, 32) + } + } + return genLinkLocal(nil) +} + +// prepareSnatAddr provides the address that an IPv6-only unsafe router should use to SNAT traffic before handing it to the operating system +func prepareSnatAddr(d Device, l *logrus.Logger, c *config.C) netip.Prefix { + if !d.Networks()[0].Addr().Is6() { + return netip.Prefix{} //if we have an IPv4 assignment within the overlay, we don't need a snat address + } + + needed := false + for _, un := range d.UnsafeNetworks() { //if we are an unsafe router for an IPv4 range + if un.Addr().Is4() { + needed = true + break + } + } + if !needed { + return netip.Prefix{} + } + + if a := c.GetString("tun.snat_address_for_4over6", ""); a != "" { + out, err := netip.ParseAddr(a) + if err != nil { + l.WithField("value", a).WithError(err).Warn("failed to parse tun.snat_address_for_4over6, will use a random value") + } else if !out.Is4() || !out.IsLinkLocalUnicast() { + l.WithField("value", out).Warn("tun.snat_address_for_4over6 must be an IPv4 address") + } else if out.IsValid() { + return netip.PrefixFrom(out, 32) + } + } + return genLinkLocal(nil) +} diff --git a/overlay/tun_android.go b/overlay/tun_android.go index eddef882..c9213cc7 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -19,14 +19,16 @@ import ( type tun struct { io.ReadWriteCloser - fd int - vpnNetworks []netip.Prefix - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *logrus.Logger + fd int + vpnNetworks []netip.Prefix + unsafeNetworks []netip.Prefix + unsafeIPv4Origin netip.Prefix + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] + l *logrus.Logger } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix) (*tun, error) { // XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly. // Be sure not to call file.Fd() as it will set the fd to blocking mode. file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") @@ -35,6 +37,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net ReadWriteCloser: file, fd: deviceFd, vpnNetworks: vpnNetworks, + unsafeNetworks: unsafeNetworks, l: l, } @@ -53,7 +56,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net return t, nil } -func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ []netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in Android") } @@ -76,6 +79,8 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } + t.unsafeIPv4Origin = prepareUnsafeOriginAddr(t, t.l, c, routes) + routeTree, err := makeRouteTree(t.l, routes, false) if err != nil { return err @@ -91,6 +96,18 @@ func (t *tun) Networks() []netip.Prefix { return t.vpnNetworks } +func (t *tun) UnsafeNetworks() []netip.Prefix { + return t.unsafeNetworks +} + +func (t *tun) UnsafeIPv4OriginAddress() netip.Prefix { + return t.unsafeIPv4Origin +} + +func (t *tun) SNATAddress() netip.Prefix { + return netip.Prefix{} +} + func (t *tun) Name() string { return "android" } diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 128c2001..1911564a 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -24,13 +24,15 @@ import ( type tun struct { io.ReadWriteCloser - Device string - vpnNetworks []netip.Prefix - DefaultMTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[routing.Gateways]] - linkAddr *netroute.LinkAddr - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + unsafeNetworks []netip.Prefix + unsafeIPv4Origin netip.Prefix + DefaultMTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] + linkAddr *netroute.LinkAddr + l *logrus.Logger // cache out buffer since we need to prepend 4 bytes for tun metadata out []byte @@ -79,7 +81,7 @@ type ifreqAlias6 struct { Lifetime addrLifetime } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) { name := c.GetString("tun.dev", "") ifIndex := -1 if name != "" && name != "utun" { @@ -127,6 +129,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( ReadWriteCloser: os.NewFile(uintptr(fd), ""), Device: name, vpnNetworks: vpnNetworks, + unsafeNetworks: unsafeNetworks, DefaultMTU: c.GetInt("tun.mtu", DefaultMTU), l: l, } @@ -153,7 +156,7 @@ func (t *tun) deviceBytes() (o [16]byte) { return } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in Darwin") } @@ -213,6 +216,11 @@ func (t *tun) Activate() error { } } } + if t.unsafeIPv4Origin.IsValid() && t.unsafeIPv4Origin.Addr().Is4() { + if err = t.activate4(t.unsafeIPv4Origin); err != nil { + return err + } + } // Run the interface ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING @@ -314,6 +322,10 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } + if initial { + t.unsafeIPv4Origin = prepareUnsafeOriginAddr(t, t.l, c, routes) + } + routeTree, err := makeRouteTree(t.l, routes, false) if err != nil { return err @@ -545,6 +557,18 @@ func (t *tun) Networks() []netip.Prefix { return t.vpnNetworks } +func (t *tun) UnsafeNetworks() []netip.Prefix { + return t.unsafeNetworks +} + +func (t *tun) UnsafeIPv4OriginAddress() netip.Prefix { + return t.unsafeIPv4Origin +} + +func (t *tun) SNATAddress() netip.Prefix { + return netip.Prefix{} +} + func (t *tun) Name() string { return t.Device } diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go index aa3dddaf..9ade55ac 100644 --- a/overlay/tun_disabled.go +++ b/overlay/tun_disabled.go @@ -52,6 +52,17 @@ func (t *disabledTun) Networks() []netip.Prefix { return t.vpnNetworks } +func (*disabledTun) UnsafeNetworks() []netip.Prefix { + return nil +} +func (*disabledTun) SNATAddress() netip.Prefix { + return netip.Prefix{} +} + +func (*disabledTun) UnsafeIPv4OriginAddress() netip.Prefix { + return netip.Prefix{} +} + func (*disabledTun) Name() string { return "disabled" } diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 2f65b3a4..e0f21769 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -86,14 +86,16 @@ type ifreqAlias6 struct { } type tun struct { - Device string - vpnNetworks []netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[routing.Gateways]] - linkAddr *netroute.LinkAddr - l *logrus.Logger - devFd int + Device string + vpnNetworks []netip.Prefix + unsafeNetworks []netip.Prefix + unsafeIPv4Origin netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] + linkAddr *netroute.LinkAddr + l *logrus.Logger + devFd int } func (t *tun) Read(to []byte) (int, error) { @@ -199,11 +201,11 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open existing tun device var fd int var err error @@ -270,11 +272,12 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( } t := &tun{ - Device: deviceName, - vpnNetworks: vpnNetworks, - MTU: c.GetInt("tun.mtu", DefaultMTU), - l: l, - devFd: fd, + Device: deviceName, + vpnNetworks: vpnNetworks, + unsafeNetworks: unsafeNetworks, + MTU: c.GetInt("tun.mtu", DefaultMTU), + l: l, + devFd: fd, } err = t.reload(c, true) @@ -410,6 +413,10 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } + if initial { + t.unsafeIPv4Origin = prepareUnsafeOriginAddr(t, t.l, c, routes) + } + routeTree, err := makeRouteTree(t.l, routes, false) if err != nil { return err @@ -446,6 +453,18 @@ func (t *tun) Networks() []netip.Prefix { return t.vpnNetworks } +func (t *tun) UnsafeNetworks() []netip.Prefix { + return t.unsafeNetworks +} + +func (t *tun) UnsafeIPv4OriginAddress() netip.Prefix { + return t.unsafeIPv4Origin +} + +func (t *tun) SNATAddress() netip.Prefix { + return netip.Prefix{} +} + func (t *tun) Name() string { return t.Device } diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index 0ce01df8..50ae4546 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -22,20 +22,23 @@ import ( type tun struct { io.ReadWriteCloser - vpnNetworks []netip.Prefix - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *logrus.Logger + vpnNetworks []netip.Prefix + unsafeNetworks []netip.Prefix + unsafeIPv4Origin netip.Prefix + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] + l *logrus.Logger } -func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ []netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in iOS") } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/tun") t := &tun{ vpnNetworks: vpnNetworks, + unsafeNetworks: unsafeNetworks, ReadWriteCloser: &tunReadCloser{f: file}, l: l, } @@ -69,6 +72,8 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } + t.unsafeIPv4Origin = prepareUnsafeOriginAddr(t, t.l, c, routes) + routeTree, err := makeRouteTree(t.l, routes, false) if err != nil { return err @@ -147,6 +152,18 @@ func (t *tun) Networks() []netip.Prefix { return t.vpnNetworks } +func (t *tun) UnsafeNetworks() []netip.Prefix { + return t.unsafeNetworks +} + +func (t *tun) UnsafeIPv4OriginAddress() netip.Prefix { + return t.unsafeIPv4Origin +} + +func (t *tun) SNATAddress() netip.Prefix { + return netip.Prefix{} +} + func (t *tun) Name() string { return "iOS" } diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 7e4aa418..1b70e8b3 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -26,14 +26,15 @@ import ( type tun struct { io.ReadWriteCloser - fd int - Device string - vpnNetworks []netip.Prefix - MaxMTU int - DefaultMTU int - TXQueueLen int - deviceIndex int - ioctlFd uintptr + fd int + Device string + vpnNetworks []netip.Prefix + unsafeNetworks []netip.Prefix + MaxMTU int + DefaultMTU int + TXQueueLen int + deviceIndex int + ioctlFd uintptr Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] @@ -46,6 +47,9 @@ type tun struct { routesFromSystem map[netip.Prefix]routing.Gateways routesFromSystemLock sync.Mutex + snatAddr netip.Prefix + unsafeIPv4Origin netip.Prefix + l *logrus.Logger } @@ -53,6 +57,18 @@ func (t *tun) Networks() []netip.Prefix { return t.vpnNetworks } +func (t *tun) UnsafeNetworks() []netip.Prefix { + return t.unsafeNetworks +} + +func (t *tun) UnsafeIPv4OriginAddress() netip.Prefix { + return t.unsafeIPv4Origin +} + +func (t *tun) SNATAddress() netip.Prefix { + return t.snatAddr +} + type ifReq struct { Name [16]byte Flags uint16 @@ -71,10 +87,10 @@ type ifreqQLEN struct { pad [8]byte } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") - t, err := newTunGeneric(c, l, file, vpnNetworks) + t, err := newTunGeneric(c, l, file, vpnNetworks, unsafeNetworks) if err != nil { return nil, err } @@ -84,7 +100,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net return t, nil } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, multiqueue bool) (*tun, error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { // If /dev/net/tun doesn't exist, try to create it (will happen in docker) @@ -123,7 +139,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu name := strings.Trim(string(req.Name[:]), "\x00") file := os.NewFile(uintptr(fd), "/dev/net/tun") - t, err := newTunGeneric(c, l, file, vpnNetworks) + t, err := newTunGeneric(c, l, file, vpnNetworks, unsafeNetworks) if err != nil { return nil, err } @@ -133,11 +149,12 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu return t, nil } -func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) { +func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix) (*tun, error) { t := &tun{ ReadWriteCloser: file, fd: int(file.Fd()), vpnNetworks: vpnNetworks, + unsafeNetworks: unsafeNetworks, TXQueueLen: c.GetInt("tun.tx_queue", 500), useSystemRoutes: c.GetBool("tun.use_system_route_table", false), useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0), @@ -170,6 +187,11 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } + if initial { + t.unsafeIPv4Origin = prepareUnsafeOriginAddr(t, t.l, c, routes) //todo MUST be different from t.snatAddr! + t.snatAddr = prepareSnatAddr(t, t.l, c) + } + routeTree, err := makeRouteTree(t.l, routes, true) if err != nil { return err @@ -313,6 +335,17 @@ func (t *tun) addIPs(link netlink.Link) error { } } + if t.unsafeIPv4Origin.IsValid() { + newAddrs = append(newAddrs, &netlink.Addr{ + IPNet: &net.IPNet{ + IP: t.unsafeIPv4Origin.Addr().AsSlice(), + Mask: net.CIDRMask(t.unsafeIPv4Origin.Bits(), t.unsafeIPv4Origin.Addr().BitLen()), + }, + Label: t.unsafeIPv4Origin.Addr().Zone(), + }) + t.l.WithField("address", t.unsafeIPv4Origin).Info("Adding origin address for IPv4 unsafe_routes") + } + //add all new addresses for i := range newAddrs { //AddrReplace still adds new IPs, but if their properties change it will change them as well @@ -400,7 +433,13 @@ func (t *tun) Activate() error { //set route MTU for i := range t.vpnNetworks { if err = t.setDefaultRoute(t.vpnNetworks[i]); err != nil { - return fmt.Errorf("failed to set default route MTU: %w", err) + return fmt.Errorf("failed to set default route MTU for %s: %w", t.vpnNetworks[i], err) + } + } + + if t.unsafeIPv4Origin.IsValid() { + if err = t.setDefaultRoute(t.unsafeIPv4Origin); err != nil { + return fmt.Errorf("failed to set default route MTU for %s: %w", t.unsafeIPv4Origin, err) } } @@ -427,6 +466,23 @@ func (t *tun) setMTU() { } } +func (t *tun) setSnatRoute() error { + dr := &net.IPNet{ + IP: t.snatAddr.Masked().Addr().AsSlice(), + Mask: net.CIDRMask(t.snatAddr.Bits(), t.snatAddr.Addr().BitLen()), + } + + nr := netlink.Route{ + LinkIndex: t.deviceIndex, + Dst: dr, + Scope: unix.RT_SCOPE_LINK, + //Protocol: unix.RTPROT_KERNEL, + Table: unix.RT_TABLE_MAIN, + Type: unix.RTN_UNICAST, + } + return netlink.RouteReplace(&nr) +} + func (t *tun) setDefaultRoute(cidr netip.Prefix) error { dr := &net.IPNet{ IP: cidr.Masked().Addr().AsSlice(), @@ -503,6 +559,13 @@ func (t *tun) addRoutes(logErrors bool) error { } } + if t.snatAddr.IsValid() { + //at least for Linux, we need to set a return route for the SNATted traffic in order to satisfy the reverse-path filter, + //and to help the kernel deliver our reply traffic to the tun device. + //however, it is important that we do not actually /assign/ the SNAT address, + //since link-local addresses will not be routed between interfaces without significant trickery. + return t.setSnatRoute() + } return nil } diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index 2986c895..448bede2 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -58,23 +58,25 @@ type addrLifetime struct { } type tun struct { - Device string - vpnNetworks []netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *logrus.Logger - f *os.File - fd int + Device string + vpnNetworks []netip.Prefix + unsafeNetworks []netip.Prefix + unsafeIPv4Origin netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] + l *logrus.Logger + f *os.File + fd int } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in NetBSD") } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open tun device var err error deviceName := c.GetString("tun.dev", "") @@ -350,6 +352,10 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } + if initial { + t.unsafeIPv4Origin = prepareUnsafeOriginAddr(t, t.l, c, routes) + } + routeTree, err := makeRouteTree(t.l, routes, false) if err != nil { return err @@ -386,6 +392,18 @@ func (t *tun) Networks() []netip.Prefix { return t.vpnNetworks } +func (t *tun) UnsafeNetworks() []netip.Prefix { + return t.unsafeNetworks +} + +func (t *tun) UnsafeIPv4OriginAddress() netip.Prefix { + return t.unsafeIPv4Origin +} + +func (t *tun) SNATAddress() netip.Prefix { + return netip.Prefix{} +} + func (t *tun) Name() string { return t.Device } diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index 9209b795..bab929d0 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -49,25 +49,27 @@ type ifreq struct { } type tun struct { - Device string - vpnNetworks []netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *logrus.Logger - f *os.File - fd int + Device string + vpnNetworks []netip.Prefix + unsafeNetworks []netip.Prefix + unsafeIPv4Origin netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] + l *logrus.Logger + f *os.File + fd int // cache out buffer since we need to prepend 4 bytes for tun metadata out []byte } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in openbsd") } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open tun device var err error deviceName := c.GetString("tun.dev", "") @@ -89,12 +91,13 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( } t := &tun{ - f: os.NewFile(uintptr(fd), ""), - fd: fd, - Device: deviceName, - vpnNetworks: vpnNetworks, - MTU: c.GetInt("tun.mtu", DefaultMTU), - l: l, + f: os.NewFile(uintptr(fd), ""), + fd: fd, + Device: deviceName, + vpnNetworks: vpnNetworks, + unsafeNetworks: unsafeNetworks, + MTU: c.GetInt("tun.mtu", DefaultMTU), + l: l, } err = t.reload(c, true) @@ -270,6 +273,10 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } + if initial { + t.unsafeIPv4Origin = prepareUnsafeOriginAddr(t, t.l, c, routes) + } + routeTree, err := makeRouteTree(t.l, routes, false) if err != nil { return err @@ -306,6 +313,18 @@ func (t *tun) Networks() []netip.Prefix { return t.vpnNetworks } +func (t *tun) UnsafeNetworks() []netip.Prefix { + return t.unsafeNetworks +} + +func (t *tun) UnsafeIPv4OriginAddress() netip.Prefix { + return t.unsafeIPv4Origin +} + +func (t *tun) SNATAddress() netip.Prefix { + return netip.Prefix{} +} + func (t *tun) Name() string { return t.Device } diff --git a/overlay/tun_snat_test.go b/overlay/tun_snat_test.go new file mode 100644 index 00000000..b340eb09 --- /dev/null +++ b/overlay/tun_snat_test.go @@ -0,0 +1,179 @@ +package overlay + +import ( + "io" + "net/netip" + "testing" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockDevice is a minimal Device implementation for testing prepareUnsafeOriginAddr. +type mockDevice struct { + networks []netip.Prefix + unsafeNetworks []netip.Prefix + snatAddr netip.Prefix + unsafeSnatAddr netip.Prefix +} + +func (d *mockDevice) Read([]byte) (int, error) { return 0, nil } +func (d *mockDevice) Write([]byte) (int, error) { return 0, nil } +func (d *mockDevice) Close() error { return nil } +func (d *mockDevice) Activate() error { return nil } +func (d *mockDevice) Networks() []netip.Prefix { return d.networks } +func (d *mockDevice) UnsafeNetworks() []netip.Prefix { return d.unsafeNetworks } +func (d *mockDevice) SNATAddress() netip.Prefix { return d.snatAddr } +func (d *mockDevice) UnsafeIPv4OriginAddress() netip.Prefix { return d.unsafeSnatAddr } +func (d *mockDevice) Name() string { return "mock" } +func (d *mockDevice) RoutesFor(netip.Addr) routing.Gateways { return routing.Gateways{} } +func (d *mockDevice) SupportsMultiqueue() bool { return false } +func (d *mockDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, nil } + +func TestPrepareSnatAddr_V4Primary_NoSnat(t *testing.T) { + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + c := config.NewC(l) + + // If the device has an IPv4 primary address, no SNAT needed + d := &mockDevice{ + networks: []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}, + } + result := prepareUnsafeOriginAddr(d, l, c, nil) + assert.Equal(t, netip.Prefix{}, result, "should not assign SNAT addr when device has IPv4 primary") +} + +func TestPrepareSnatAddr_V6Primary_NoUnsafeOrRoutes(t *testing.T) { + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + c := config.NewC(l) + + // IPv6 primary but no unsafe networks or IPv4 routes + d := &mockDevice{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + } + result := prepareUnsafeOriginAddr(d, l, c, nil) + assert.Equal(t, netip.Prefix{}, result, "should not assign SNAT addr without IPv4 unsafe networks or routes") +} + +func TestPrepareSnatAddr_V6Primary_WithV4Unsafe(t *testing.T) { + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + c := config.NewC(l) + + // IPv6 primary with IPv4 unsafe network -> should get SNAT addr + d := &mockDevice{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + result := prepareSnatAddr(d, l, c) + require.True(t, result.IsValid(), "should assign SNAT addr") + assert.True(t, result.Addr().Is4(), "SNAT addr should be IPv4") + assert.True(t, result.Addr().IsLinkLocalUnicast(), "SNAT addr should be link-local") + assert.Equal(t, 32, result.Bits(), "SNAT addr should be /32") + + result = prepareUnsafeOriginAddr(d, l, c, nil) + require.False(t, result.IsValid(), "no routes = no origin addr needed") +} + +func TestPrepareUnsafeOriginAddr_V6Primary_WithV4Route(t *testing.T) { + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + c := config.NewC(l) + + // IPv6 primary with IPv4 route -> should get SNAT addr + d := &mockDevice{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + } + routes := []Route{ + {Cidr: netip.MustParsePrefix("10.0.0.0/8")}, + } + result := prepareUnsafeOriginAddr(d, l, c, routes) + require.True(t, result.IsValid(), "should assign SNAT addr when IPv4 route exists") + assert.True(t, result.Addr().Is4()) + assert.True(t, result.Addr().IsLinkLocalUnicast()) + + result = prepareSnatAddr(d, l, c) + require.False(t, result.IsValid(), "no UnsafeNetworks = no snat addr needed") +} + +func TestPrepareSnatAddr_V6Primary_V6UnsafeOnly(t *testing.T) { + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + c := config.NewC(l) + + // IPv6 primary with only IPv6 unsafe network -> no SNAT needed + d := &mockDevice{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("fd01::/64")}, + } + result := prepareUnsafeOriginAddr(d, l, c, nil) + assert.Equal(t, netip.Prefix{}, result, "should not assign SNAT addr for IPv6-only unsafe networks") +} + +func TestPrepareSnatAddr_ManualAddress(t *testing.T) { + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + c := config.NewC(l) + c.Settings["tun"] = map[string]any{ + "snat_address_for_4over6": "169.254.42.42", + } + + d := &mockDevice{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + result := prepareSnatAddr(d, l, c) + require.True(t, result.IsValid()) + assert.Equal(t, netip.MustParseAddr("169.254.42.42"), result.Addr()) + assert.Equal(t, 32, result.Bits()) +} + +func TestPrepareSnatAddr_InvalidManualAddress_Fallback(t *testing.T) { + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + c := config.NewC(l) + c.Settings["tun"] = map[string]any{ + "snat_address_for_4over6": "not-an-ip", + } + + d := &mockDevice{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + result := prepareSnatAddr(d, l, c) + // Should fall back to auto-assignment + require.True(t, result.IsValid(), "should fall back to auto-assigned address") + assert.True(t, result.Addr().Is4()) + assert.True(t, result.Addr().IsLinkLocalUnicast()) +} + +func TestPrepareSnatAddr_AutoGenerated_Range(t *testing.T) { + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + c := config.NewC(l) + + d := &mockDevice{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + + // Generate several addresses and verify they're all in the expected range + for i := 0; i < 100; i++ { + result := prepareSnatAddr(d, l, c) + require.True(t, result.IsValid()) + addr := result.Addr() + octets := addr.As4() + assert.Equal(t, byte(169), octets[0], "first octet should be 169") + assert.Equal(t, byte(254), octets[1], "second octet should be 254") + // Should not have .0 in the last octet + assert.NotEqual(t, byte(0), octets[3], "last octet should not be 0") + // Should not be 169.254.255.255 (broadcast) + if octets[2] == 255 { + assert.NotEqual(t, byte(255), octets[3], "should not be broadcast address") + } + } +} diff --git a/overlay/tun_test.go b/overlay/tun_test.go new file mode 100644 index 00000000..1e247aad --- /dev/null +++ b/overlay/tun_test.go @@ -0,0 +1,37 @@ +package overlay + +import ( + "bytes" + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLinkLocal(t *testing.T) { + r := bytes.NewReader([]byte{42, 99}) + result := genLinkLocal(r) + assert.Equal(t, netip.MustParsePrefix("169.254.42.99/32"), result, "genLinkLocal with a deterministic randomizer") + + result = genLinkLocal(nil) + assert.True(t, result.IsValid(), "genLinkLocal with nil randomizer should be valid") + assert.True(t, result.Addr().IsLinkLocalUnicast(), "genLinkLocal with nil randomizer should be link-local") + + result = coerceLinkLocal([]byte{169, 254, 100, 50}) + assert.Equal(t, netip.MustParsePrefix("169.254.100.50/32"), result, "coerceLinkLocal should pass through normal values") + + result = coerceLinkLocal([]byte{169, 254, 0, 0}) + assert.Equal(t, netip.MustParsePrefix("169.254.0.1/32"), result, "coerceLinkLocal should bump .0 last octet to .1") + + result = coerceLinkLocal([]byte{169, 254, 255, 255}) + assert.Equal(t, netip.MustParsePrefix("169.254.255.254/32"), result, "coerceLinkLocal should bump broadcast 255.255 to 255.254") + + result = coerceLinkLocal([]byte{169, 254, 0, 1}) + assert.Equal(t, netip.MustParsePrefix("169.254.0.1/32"), result, "coerceLinkLocal should leave .1 last octet unchanged") + + result = coerceLinkLocal([]byte{169, 254, 255, 254}) + assert.Equal(t, netip.MustParsePrefix("169.254.255.254/32"), result, "coerceLinkLocal should leave 255.254 unchanged") + + result = coerceLinkLocal([]byte{169, 254, 255, 100}) + assert.Equal(t, netip.MustParsePrefix("169.254.255.100/32"), result, "coerceLinkLocal should leave 255.100 unchanged") +} diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index 3477de3d..234d9336 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -17,18 +17,21 @@ import ( ) type TestTun struct { - Device string - vpnNetworks []netip.Prefix - Routes []Route - routeTree *bart.Table[routing.Gateways] - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + unsafeNetworks []netip.Prefix + snatAddr netip.Prefix + unsafeIPv4Origin netip.Prefix + Routes []Route + routeTree *bart.Table[routing.Gateways] + l *logrus.Logger closed atomic.Bool rxPackets chan []byte // Packets to receive into nebula TxPackets chan []byte // Packets transmitted outside by nebula } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*TestTun, error) { _, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true) if err != nil { return nil, err @@ -38,18 +41,22 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( return nil, err } - return &TestTun{ - Device: c.GetString("tun.dev", ""), - vpnNetworks: vpnNetworks, - Routes: routes, - routeTree: routeTree, - l: l, - rxPackets: make(chan []byte, 10), - TxPackets: make(chan []byte, 10), - }, nil + tt := &TestTun{ + Device: c.GetString("tun.dev", ""), + vpnNetworks: vpnNetworks, + unsafeNetworks: unsafeNetworks, + Routes: routes, + routeTree: routeTree, + l: l, + rxPackets: make(chan []byte, 10), + TxPackets: make(chan []byte, 10), + } + tt.unsafeIPv4Origin = prepareUnsafeOriginAddr(tt, l, c, routes) + tt.snatAddr = prepareSnatAddr(tt, tt.l, c) + return tt, nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*TestTun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*TestTun, error) { return nil, fmt.Errorf("newTunFromFd not supported") } @@ -139,3 +146,15 @@ func (t *TestTun) SupportsMultiqueue() bool { func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented") } + +func (t *TestTun) UnsafeNetworks() []netip.Prefix { + return t.unsafeNetworks +} + +func (t *TestTun) UnsafeIPv4OriginAddress() netip.Prefix { + return t.unsafeIPv4Origin +} + +func (t *TestTun) SNATAddress() netip.Prefix { + return t.snatAddr +} diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index 223eabee..303b61d8 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -28,21 +28,23 @@ import ( const tunGUIDLabel = "Fixed Nebula Windows GUID v1" type winTun struct { - Device string - vpnNetworks []netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + unsafeNetworks []netip.Prefix + unsafeIPv4Origin netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] + l *logrus.Logger tun *wintun.NativeTun } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (Device, error) { return nil, fmt.Errorf("newTunFromFd not supported in Windows") } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*winTun, error) { err := checkWinTunExists() if err != nil { return nil, fmt.Errorf("can not load the wintun driver: %w", err) @@ -55,10 +57,11 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( } t := &winTun{ - Device: deviceName, - vpnNetworks: vpnNetworks, - MTU: c.GetInt("tun.mtu", DefaultMTU), - l: l, + Device: deviceName, + vpnNetworks: vpnNetworks, + unsafeNetworks: unsafeNetworks, + MTU: c.GetInt("tun.mtu", DefaultMTU), + l: l, } err = t.reload(c, true) @@ -102,6 +105,10 @@ func (t *winTun) reload(c *config.C, initial bool) error { return nil } + if initial { + t.unsafeIPv4Origin = prepareUnsafeOriginAddr(t, t.l, c, routes) + } + routeTree, err := makeRouteTree(t.l, routes, false) if err != nil { return err @@ -132,7 +139,12 @@ func (t *winTun) reload(c *config.C, initial bool) error { func (t *winTun) Activate() error { luid := winipcfg.LUID(t.tun.LUID()) - err := luid.SetIPAddresses(t.vpnNetworks) + prefixes := t.vpnNetworks + if t.unsafeIPv4Origin.IsValid() { + prefixes = append(prefixes, t.unsafeIPv4Origin) + } + + err := luid.SetIPAddresses(prefixes) if err != nil { return fmt.Errorf("failed to set address: %w", err) } @@ -225,6 +237,18 @@ func (t *winTun) Networks() []netip.Prefix { return t.vpnNetworks } +func (t *winTun) UnsafeNetworks() []netip.Prefix { + return t.unsafeNetworks +} + +func (t *winTun) UnsafeIPv4OriginAddress() netip.Prefix { + return t.unsafeIPv4Origin +} + +func (t *winTun) SNATAddress() netip.Prefix { + return netip.Prefix{} +} + func (t *winTun) Name() string { return t.Device } diff --git a/overlay/user.go b/overlay/user.go index 1f92d4e9..87eee029 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -9,7 +9,7 @@ import ( "github.com/slackhq/nebula/routing" ) -func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { +func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error) { return NewUserDevice(vpnNetworks) } @@ -36,6 +36,17 @@ type UserDevice struct { inboundWriter *io.PipeWriter } +func (d *UserDevice) UnsafeNetworks() []netip.Prefix { + return nil +} + +func (d *UserDevice) SNATAddress() netip.Prefix { + return netip.Prefix{} +} +func (d *UserDevice) UnsafeIPv4OriginAddress() netip.Prefix { + return netip.Prefix{} +} + func (d *UserDevice) Activate() error { return nil } diff --git a/pki.go b/pki.go index 19869d58..3deb0fe7 100644 --- a/pki.go +++ b/pki.go @@ -91,7 +91,7 @@ func (p *PKI) reload(c *config.C, initial bool) error { } func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError { - newState, err := newCertStateFromConfig(c) + newState, err := newCertStateFromConfig(c, p.l) if err != nil { return util.NewContextualError("Could not load client cert", nil, err) } @@ -102,7 +102,7 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError { if currentState.v1Cert == nil { //adding certs is fine, actually. Networks-in-common confirmed in newCertState(). } else { - // did IP in cert change? if so, don't set + // did IP in cert change? if so, don't set. If we ever allow this, need to set p.firewallReloadNeeded if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) { return util.NewContextualError( "Networks in new cert was different from old", @@ -158,6 +158,14 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError { } } + newUN := newState.GetDefaultCertificate().UnsafeNetworks() + oldUN := currentState.GetDefaultCertificate().UnsafeNetworks() + if !slices.Equal(newUN, oldUN) { + //todo I don't love this, because other clients will see the new assignments and act on them, but we will not be able to. + //I think we need to wire this into the firewall reload. + p.l.WithFields(m{"previous": oldUN, "new": newUN}).Warning("UnsafeNetworks assignments differ. A restart is required in order for this to take effect.") + } + // Cipher cant be hot swapped so just leave it at what it was before newState.cipher = currentState.cipher @@ -260,7 +268,7 @@ func (cs *CertState) MarshalJSON() ([]byte, error) { return json.Marshal(msg) } -func newCertStateFromConfig(c *config.C) (*CertState, error) { +func newCertStateFromConfig(c *config.C, l *logrus.Logger) (*CertState, error) { var err error privPathOrPEM := c.GetString("pki.key", "") @@ -344,10 +352,33 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) { return nil, fmt.Errorf("unknown pki.initiating_version: %v", rawInitiatingVersion) } - return newCertState(initiatingVersion, v1, v2, isPkcs11, curve, rawKey) + return newCertState(l, initiatingVersion, v1, v2, isPkcs11, curve, rawKey) } -func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) { +func compareUnsafeNetworksAcrossCertVersions(v1, v2 cert.Certificate) error { + if v1 == nil || v2 == nil { + return nil //can't be a problem if we don't have one of the kinds of cert + } + + v4UnsafeNets := 0 + for _, n := range v2.UnsafeNetworks() { + if n.Addr().Is6() { + continue // V1 certs can't have IPv6 unsafe networks + } else { + v4UnsafeNets++ + } + if !slices.Contains(v1.UnsafeNetworks(), n) { + return errors.New("UnsafeNetworks mismatch") + } + } + if len(v1.UnsafeNetworks()) != v4UnsafeNets { + return errors.New("UnsafeNetworks mismatch") + } + + return nil +} + +func newCertState(l *logrus.Logger, dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) { cs := CertState{ privateKey: privateKey, pkcs11Backed: pkcs11backed, @@ -370,6 +401,12 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p } cs.initiatingVersion = dv + + warn := compareUnsafeNetworksAcrossCertVersions(v1, v2) + if warn != nil { + l.WithFields(m{"UnsafeNetworksV1": v1.UnsafeNetworks(), "UnsafeNetworksV2": v2.UnsafeNetworks()}). + Warning("the IPv4 UnsafeNetworks assigned in the V1 certificate do not match the ones in V2") + } } if v1 != nil { diff --git a/snat.go b/snat.go new file mode 100644 index 00000000..3164b641 --- /dev/null +++ b/snat.go @@ -0,0 +1,91 @@ +package nebula + +import ( + "encoding/binary" + "net/netip" +) + +func recalcIPv4Checksum(data []byte, oldSrcIP netip.Addr, newSrcIP netip.Addr) { + oldChecksum := binary.BigEndian.Uint16(data[10:12]) + //because of how checksums work, we can re-use this function + checksum := calcNewTransportChecksum(oldChecksum, oldSrcIP, 0, newSrcIP, 0) + binary.BigEndian.PutUint16(data[10:12], checksum) +} + +func calcNewTransportChecksum(oldChecksum uint16, oldSrcIP netip.Addr, oldSrcPort uint16, newSrcIP netip.Addr, newSrcPort uint16) uint16 { + oldIP := binary.BigEndian.Uint32(oldSrcIP.AsSlice()) + newIP := binary.BigEndian.Uint32(newSrcIP.AsSlice()) + + // Start with inverted checksum + checksum := uint32(^oldChecksum) + + // Subtract old IP (as two 16-bit words) + checksum += uint32(^uint16(oldIP >> 16)) + checksum += uint32(^uint16(oldIP & 0xFFFF)) + + // Subtract old port + checksum += uint32(^oldSrcPort) + + // Add new IP (as two 16-bit words) + checksum += uint32(newIP >> 16) + checksum += uint32(newIP & 0xFFFF) + + // Add new port + checksum += uint32(newSrcPort) + + // Fold carries + for checksum > 0xFFFF { + checksum = (checksum & 0xFFFF) + (checksum >> 16) + } + + // Return ones' complement + return ^uint16(checksum) +} + +func recalcV4TransportChecksum(offsetInsideHeader int, data []byte, oldSrcIP netip.AddrPort, newSrcIP netip.AddrPort) { + ipHeaderOffset := int(data[0]&0x0F) * 4 + offset := ipHeaderOffset + offsetInsideHeader + oldcsum := binary.BigEndian.Uint16(data[offset : offset+2]) + checksum := calcNewTransportChecksum(oldcsum, oldSrcIP.Addr(), oldSrcIP.Port(), newSrcIP.Addr(), newSrcIP.Port()) + binary.BigEndian.PutUint16(data[offset:offset+2], checksum) +} + +func recalcUDPv4Checksum(data []byte, oldSrcIP netip.AddrPort, newSrcIP netip.AddrPort) { + const offsetInsideHeader = 6 + recalcV4TransportChecksum(offsetInsideHeader, data, oldSrcIP, newSrcIP) +} + +func recalcTCPv4Checksum(data []byte, oldSrcIP netip.AddrPort, newSrcIP netip.AddrPort) { + const offsetInsideHeader = 16 + recalcV4TransportChecksum(offsetInsideHeader, data, oldSrcIP, newSrcIP) +} + +func calcNewICMPChecksum(oldChecksum uint16, oldCode uint16, newCode uint16, oldID uint16, newID uint16) uint16 { + // Start with inverted checksum + checksum := uint32(^oldChecksum) + + // Subtract old stuff + checksum += uint32(^oldCode) + checksum += uint32(^oldID) + + // Add new stuff + checksum += uint32(newCode) + checksum += uint32(newID) + + // Fold carries + for checksum > 0xFFFF { + checksum = (checksum & 0xFFFF) + (checksum >> 16) + } + + // Return ones' complement + return ^uint16(checksum) +} + +func recalcICMPv4Checksum(data []byte, oldCode uint16, newCode uint16, oldID uint16, newID uint16) { + const offsetInsideHeader = 2 + ipHeaderOffset := int(data[0]&0x0F) * 4 + offset := ipHeaderOffset + offsetInsideHeader + oldChecksum := binary.BigEndian.Uint16(data[offset : offset+2]) + checksum := calcNewICMPChecksum(oldChecksum, oldCode, newCode, oldID, newID) + binary.BigEndian.PutUint16(data[offset:offset+2], checksum) +} diff --git a/snat_test.go b/snat_test.go new file mode 100644 index 00000000..1bba83ce --- /dev/null +++ b/snat_test.go @@ -0,0 +1,1310 @@ +package nebula + +import ( + "encoding/binary" + "net/netip" + "slices" + "testing" + "time" + + "github.com/gaissmai/bart" + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/firewall" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Canonical test packets with all checksums computed from scratch by +// /tmp/gen_canonical.go. Tests feed these into the production rewrite +// functions and compare byte-for-byte against expected outputs. + +// canonicalUDP: src=10.0.0.1:12345 dst=192.168.1.1:80 proto=UDP payload="hello world" +var canonicalUDP = []byte{ + 0x45, 0x00, 0x00, 0x27, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x5c, 0xe8, 0x0a, 0x00, 0x00, 0x01, + 0xc0, 0xa8, 0x01, 0x01, 0x30, 0x39, 0x00, 0x50, 0x00, 0x13, 0x71, 0xc6, 0x68, 0x65, 0x6c, 0x6c, + 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64, +} + +// canonicalTCP: src=10.0.0.1:12345 dst=192.168.1.1:80 proto=TCP payload="GET / HTTP/1.1" +var canonicalTCP = []byte{ + 0x45, 0x00, 0x00, 0x36, 0x12, 0x34, 0x40, 0x00, 0x40, 0x06, 0x5c, 0xe4, 0x0a, 0x00, 0x00, 0x01, + 0xc0, 0xa8, 0x01, 0x01, 0x30, 0x39, 0x00, 0x50, 0x12, 0x34, 0x56, 0x78, 0x00, 0x00, 0x00, 0x00, + 0x50, 0x02, 0xff, 0xff, 0x86, 0x68, 0x00, 0x00, 0x47, 0x45, 0x54, 0x20, 0x2f, 0x20, 0x48, 0x54, + 0x54, 0x50, 0x2f, 0x31, 0x2e, 0x31, +} + +// canonicalICMP: src=10.0.0.1 dst=192.168.1.1 proto=ICMP echo, id=0x1234 seq=1 +var canonicalICMP = []byte{ + 0x45, 0x00, 0x00, 0x1c, 0x12, 0x34, 0x40, 0x00, 0x40, 0x01, 0x5d, 0x03, 0x0a, 0x00, 0x00, 0x01, + 0xc0, 0xa8, 0x01, 0x01, 0x08, 0x00, 0xe5, 0xca, 0x12, 0x34, 0x00, 0x01, +} + +// canonicalUDPReply: src=192.168.1.1:80 dst=169.254.55.96:55555 proto=UDP payload="reply" +var canonicalUDPReply = []byte{ + 0x45, 0x00, 0x00, 0x21, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x85, 0x90, 0xc0, 0xa8, 0x01, 0x01, + 0xa9, 0xfe, 0x37, 0x60, 0x00, 0x50, 0xd9, 0x03, 0x00, 0x0d, 0x27, 0xa6, 0x72, 0x65, 0x70, 0x6c, + 0x79, +} + +// canonicalUDPTest: src=10.0.0.1:12345 dst=192.168.1.1:80 proto=UDP payload="test" +var canonicalUDPTest = []byte{ + 0x45, 0x00, 0x00, 0x20, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x5c, 0xef, 0x0a, 0x00, 0x00, 0x01, + 0xc0, 0xa8, 0x01, 0x01, 0x30, 0x39, 0x00, 0x50, 0x00, 0x0c, 0x1b, 0xc9, 0x74, 0x65, 0x73, 0x74, +} + +// canonicalUDPHijack: src=10.0.0.1:12345 dst=192.168.1.1:80 proto=UDP payload="hijack" +var canonicalUDPHijack = []byte{ + 0x45, 0x00, 0x00, 0x22, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x5c, 0xed, 0x0a, 0x00, 0x00, 0x01, + 0xc0, 0xa8, 0x01, 0x01, 0x30, 0x39, 0x00, 0x50, 0x00, 0x0e, 0xcd, 0x68, 0x68, 0x69, 0x6a, 0x61, + 0x63, 0x6b, +} + +// canonicalUDPBlocked: src=10.0.0.1:12345 dst=192.168.1.1:443 proto=UDP payload="blocked" +var canonicalUDPBlocked = []byte{ + 0x45, 0x00, 0x00, 0x23, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x5c, 0xec, 0x0a, 0x00, 0x00, 0x01, + 0xc0, 0xa8, 0x01, 0x01, 0x30, 0x39, 0x01, 0xbb, 0x00, 0x0f, 0x60, 0xfc, 0x62, 0x6c, 0x6f, 0x63, + 0x6b, 0x65, 0x64, +} + +// canonicalUDPWrongDest: src=10.0.0.1:12345 dst=172.16.0.1:80 proto=UDP payload="wrong dest" +var canonicalUDPWrongDest = []byte{ + 0x45, 0x00, 0x00, 0x26, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x72, 0x81, 0x0a, 0x00, 0x00, 0x01, + 0xac, 0x10, 0x00, 0x01, 0x30, 0x39, 0x00, 0x50, 0x00, 0x12, 0xf3, 0x53, 0x77, 0x72, 0x6f, 0x6e, + 0x67, 0x20, 0x64, 0x65, 0x73, 0x74, +} + +// canonicalUDPNoSnat: src=10.0.0.1:12345 dst=192.168.1.1:80 proto=UDP payload="no snat" +var canonicalUDPNoSnat = []byte{ + 0x45, 0x00, 0x00, 0x23, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x5c, 0xec, 0x0a, 0x00, 0x00, 0x01, + 0xc0, 0xa8, 0x01, 0x01, 0x30, 0x39, 0x00, 0x50, 0x00, 0x0f, 0x92, 0x58, 0x6e, 0x6f, 0x20, 0x73, + 0x6e, 0x61, 0x74, +} + +// canonicalUDPV4Traffic: src=10.128.0.2:12345 dst=192.168.1.1:80 proto=UDP payload="v4 traffic" +var canonicalUDPV4Traffic = []byte{ + 0x45, 0x00, 0x00, 0x26, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x5c, 0x68, 0x0a, 0x80, 0x00, 0x02, + 0xc0, 0xa8, 0x01, 0x01, 0x30, 0x39, 0x00, 0x50, 0x00, 0x12, 0x2a, 0x42, 0x76, 0x34, 0x20, 0x74, + 0x72, 0x61, 0x66, 0x66, 0x69, 0x63, +} + +// canonicalUDPRoundtrip: src=10.0.0.1:12345 dst=192.168.1.1:80 proto=UDP payload="roundtrip" +var canonicalUDPRoundtrip = []byte{ + 0x45, 0x00, 0x00, 0x25, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x5c, 0xea, 0x0a, 0x00, 0x00, 0x01, + 0xc0, 0xa8, 0x01, 0x01, 0x30, 0x39, 0x00, 0x50, 0x00, 0x11, 0xd4, 0xdc, 0x72, 0x6f, 0x75, 0x6e, + 0x64, 0x74, 0x72, 0x69, 0x70, +} + +// canonicalUDPSnatMe: src=10.0.0.1:12345 dst=192.168.1.1:80 proto=UDP payload="snat me" +var canonicalUDPSnatMe = []byte{ + 0x45, 0x00, 0x00, 0x23, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x5c, 0xec, 0x0a, 0x00, 0x00, 0x01, + 0xc0, 0xa8, 0x01, 0x01, 0x30, 0x39, 0x00, 0x50, 0x00, 0x0f, 0xa9, 0x4c, 0x73, 0x6e, 0x61, 0x74, + 0x20, 0x6d, 0x65, +} + +// Expected outputs after rewriting — built from scratch with the post-rewrite +// addresses, so all checksums are independently correct. + +// canonicalUDPSnatted: canonicalUDP with src rewritten to 169.254.55.96:55555 +var canonicalUDPSnatted = []byte{ + 0x45, 0x00, 0x00, 0x27, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x85, 0x8a, 0xa9, 0xfe, 0x37, 0x60, + 0xc0, 0xa8, 0x01, 0x01, 0xd9, 0x03, 0x00, 0x50, 0x00, 0x13, 0xf1, 0x9d, 0x68, 0x65, 0x6c, 0x6c, + 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64, +} + +// canonicalUDPReplyUnSnatted: canonicalUDPReply with dst rewritten from 169.254.55.96:55555 to 10.0.0.1:12345 +var canonicalUDPReplyUnSnatted = []byte{ + 0x45, 0x00, 0x00, 0x21, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x5c, 0xee, 0xc0, 0xa8, 0x01, 0x01, + 0x0a, 0x00, 0x00, 0x01, 0x00, 0x50, 0x30, 0x39, 0x00, 0x0d, 0xa7, 0xce, 0x72, 0x65, 0x70, 0x6c, + 0x79, +} + +// canonicalTCPSnatted: canonicalTCP with src rewritten to 169.254.55.96:55555 +var canonicalTCPSnatted = []byte{ + 0x45, 0x00, 0x00, 0x36, 0x12, 0x34, 0x40, 0x00, 0x40, 0x06, 0x85, 0x86, 0xa9, 0xfe, 0x37, 0x60, + 0xc0, 0xa8, 0x01, 0x01, 0xd9, 0x03, 0x00, 0x50, 0x12, 0x34, 0x56, 0x78, 0x00, 0x00, 0x00, 0x00, + 0x50, 0x02, 0xff, 0xff, 0x06, 0x40, 0x00, 0x00, 0x47, 0x45, 0x54, 0x20, 0x2f, 0x20, 0x48, 0x54, + 0x54, 0x50, 0x2f, 0x31, 0x2e, 0x31, +} + +// canonicalICMPSnatted: canonicalICMP with src rewritten to 169.254.55.96, id changed from 0x1234 to 0x5678 +var canonicalICMPSnatted = []byte{ + 0x45, 0x00, 0x00, 0x1c, 0x12, 0x34, 0x40, 0x00, 0x40, 0x01, 0x85, 0xa5, 0xa9, 0xfe, 0x37, 0x60, + 0xc0, 0xa8, 0x01, 0x01, 0x08, 0x00, 0xa1, 0x86, 0x56, 0x78, 0x00, 0x01, +} + +func TestCalcNewTransportChecksum_Identity(t *testing.T) { + // Rewriting to the same IP/port should return the same checksum + ip := netip.MustParseAddr("10.0.0.1") + result := calcNewTransportChecksum(0x1234, ip, 80, ip, 80) + assert.Equal(t, uint16(0x1234), result) +} + +func TestCalcNewTransportChecksum_VsCanonical(t *testing.T) { + srcIP := netip.MustParseAddr("10.0.0.1") + snatIP := netip.MustParseAddr("169.254.55.96") + + // Extract the original UDP checksum from canonicalUDP (bytes 26-27) + origChecksum := binary.BigEndian.Uint16(canonicalUDP[26:28]) + + // Compute incrementally + incremental := calcNewTransportChecksum(origChecksum, srcIP, 12345, snatIP, 55555) + + // Verify it matches the checksum in the independently-computed canonicalUDPSnatted + expectedChecksum := binary.BigEndian.Uint16(canonicalUDPSnatted[26:28]) + assert.Equal(t, expectedChecksum, incremental, "incremental checksum should match canonical expected output") +} + +func TestCalcNewICMPChecksum_Identity(t *testing.T) { + // Same values in and out should be identity + result := calcNewICMPChecksum(0xABCD, 0, 0, 1234, 1234) + assert.Equal(t, uint16(0xABCD), result) +} + +func TestRewritePacket_UDP(t *testing.T) { + srcIP := netip.MustParseAddr("10.0.0.1") + dstIP := netip.MustParseAddr("192.168.1.1") + snatIP := netip.MustParseAddr("169.254.55.96") + + pkt := slices.Clone(canonicalUDP) + + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: srcIP, + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + + // SNAT rewrites source: IP at offset 12, port at offset 0 inside transport + oldIP := netip.AddrPortFrom(srcIP, 12345) + newIP := netip.AddrPortFrom(snatIP, 55555) + rewritePacket(pkt, &fp, oldIP, newIP, 12, 0) + + assert.Equal(t, canonicalUDPSnatted, pkt, "rewritten packet should match canonical expected output") +} + +func TestRewritePacket_UDP_UnSNAT(t *testing.T) { + snatIP := netip.MustParseAddr("169.254.55.96") + dstIP := netip.MustParseAddr("192.168.1.1") + origSrcIP := netip.MustParseAddr("10.0.0.1") + + pkt := slices.Clone(canonicalUDPReply) + + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: snatIP, + LocalPort: 80, + RemotePort: 55555, + Protocol: firewall.ProtoUDP, + } + + // UnSNAT rewrites destination: IP at offset 16, port at offset 2 inside transport + oldIP := netip.AddrPortFrom(snatIP, 55555) + newIP := netip.AddrPortFrom(origSrcIP, 12345) + rewritePacket(pkt, &fp, oldIP, newIP, 16, 2) + + assert.Equal(t, canonicalUDPReplyUnSnatted, pkt, "un-SNATted packet should match canonical expected output") +} + +func TestRewritePacket_TCP(t *testing.T) { + srcIP := netip.MustParseAddr("10.0.0.1") + dstIP := netip.MustParseAddr("192.168.1.1") + snatIP := netip.MustParseAddr("169.254.55.96") + + pkt := slices.Clone(canonicalTCP) + + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: srcIP, + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoTCP, + } + + oldIP := netip.AddrPortFrom(srcIP, 12345) + newIP := netip.AddrPortFrom(snatIP, 55555) + rewritePacket(pkt, &fp, oldIP, newIP, 12, 0) + + assert.Equal(t, canonicalTCPSnatted, pkt, "rewritten TCP packet should match canonical expected output") +} + +func TestRewritePacket_ICMP(t *testing.T) { + srcIP := netip.MustParseAddr("10.0.0.1") + dstIP := netip.MustParseAddr("192.168.1.1") + snatIP := netip.MustParseAddr("169.254.55.96") + + pkt := slices.Clone(canonicalICMP) + + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: srcIP, + LocalPort: 0, + RemotePort: 0x1234, // ICMP ID used as port + Protocol: firewall.ProtoICMP, + } + + oldIP := netip.AddrPortFrom(srcIP, 0x1234) + newIP := netip.AddrPortFrom(snatIP, 0x5678) + rewritePacket(pkt, &fp, oldIP, newIP, 12, 0) + + assert.Equal(t, canonicalICMPSnatted, pkt, "rewritten ICMP packet should match canonical expected output") +} + +func TestRewritePacket_Roundtrip(t *testing.T) { + // Test that SNAT followed by unSNAT produces the original packet + srcIP := netip.MustParseAddr("10.0.0.1") + dstIP := netip.MustParseAddr("192.168.1.1") + snatIP := netip.MustParseAddr("169.254.55.96") + + pkt := slices.Clone(canonicalUDPRoundtrip) + + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: srcIP, + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + + // SNAT: rewrite source + oldSrc := netip.AddrPortFrom(srcIP, 12345) + newSrc := netip.AddrPortFrom(snatIP, 55555) + rewritePacket(pkt, &fp, oldSrc, newSrc, 12, 0) + + // Verify intermediate state is not the original + require.NotEqual(t, canonicalUDPRoundtrip, pkt) + + // UnSNAT: rewrite source back + rewritePacket(pkt, &fp, newSrc, oldSrc, 12, 0) + + // Packet should be byte-for-byte identical to original + assert.Equal(t, canonicalUDPRoundtrip, pkt, "packet should be identical after roundtrip SNAT/unSNAT") +} + +func TestSnatInfo_Valid(t *testing.T) { + t.Run("nil is invalid", func(t *testing.T) { + var s *snatInfo + assert.False(t, s.Valid()) + }) + + t.Run("zero value is invalid", func(t *testing.T) { + s := &snatInfo{} + assert.False(t, s.Valid()) + }) + + t.Run("with valid src is valid", func(t *testing.T) { + s := &snatInfo{ + Src: netip.AddrPortFrom(netip.MustParseAddr("10.0.0.1"), 1234), + SrcVpnIp: netip.MustParseAddr("fd00::1"), + SnatPort: 55555, + } + assert.True(t, s.Valid()) + }) +} + +func TestFirewall_ShouldUnSNAT(t *testing.T) { + snatAddr := netip.MustParseAddr("169.254.55.96") + + t.Run("no snat addr configured", func(t *testing.T) { + fw := &Firewall{} + fp := &firewall.Packet{RemoteAddr: snatAddr} + assert.False(t, fw.ShouldUnSNAT(fp)) + }) + + t.Run("packet to snat addr", func(t *testing.T) { + fw := &Firewall{snatAddr: snatAddr} + fp := &firewall.Packet{RemoteAddr: snatAddr} + assert.True(t, fw.ShouldUnSNAT(fp)) + }) + + t.Run("packet to different addr", func(t *testing.T) { + fw := &Firewall{snatAddr: snatAddr} + fp := &firewall.Packet{RemoteAddr: netip.MustParseAddr("10.0.0.1")} + assert.False(t, fw.ShouldUnSNAT(fp)) + }) +} + +func TestFirewall_IdentifyNetworkType_SNATPeer(t *testing.T) { + snatAddr := netip.MustParseAddr("169.254.55.96") + + t.Run("v4 packet from v6-only host without networks table", func(t *testing.T) { + fw := &Firewall{snatAddr: snatAddr} + h := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("fd00::1")}} + fp := firewall.Packet{ + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalAddr: netip.MustParseAddr("192.168.1.1"), + } + assert.Equal(t, NetworkTypeUnverifiedSNATPeer, fw.identifyRemoteNetworkType(h, fp)) + }) + + t.Run("v4 packet from v4 host is not snat peer", func(t *testing.T) { + fw := &Firewall{snatAddr: snatAddr} + h := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("10.0.0.1")}} + fp := firewall.Packet{ + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalAddr: netip.MustParseAddr("192.168.1.1"), + } + assert.Equal(t, NetworkTypeVPN, fw.identifyRemoteNetworkType(h, fp)) + }) + + t.Run("v6 packet from v6 host is VPN", func(t *testing.T) { + fw := &Firewall{snatAddr: snatAddr} + h := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("fd00::1")}} + fp := firewall.Packet{ + RemoteAddr: netip.MustParseAddr("fd00::1"), + LocalAddr: netip.MustParseAddr("fd00::2"), + } + assert.Equal(t, NetworkTypeVPN, fw.identifyRemoteNetworkType(h, fp)) + }) + + t.Run("mismatched v4 from v4 host is invalid", func(t *testing.T) { + fw := &Firewall{snatAddr: snatAddr} + h := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("10.0.0.1")}} + fp := firewall.Packet{ + RemoteAddr: netip.MustParseAddr("10.0.0.99"), + LocalAddr: netip.MustParseAddr("192.168.1.1"), + } + assert.Equal(t, NetworkTypeInvalidPeer, fw.identifyRemoteNetworkType(h, fp)) + }) +} + +func TestFirewall_AllowNetworkType_SNAT(t *testing.T) { + //todo fix! + //t.Run("snat peer allowed with snat addr", func(t *testing.T) { + // fw := &Firewall{snatAddr: netip.MustParseAddr("169.254.55.96")} + // assert.NoError(t, fw.allowRemoteNetworkType(NetworkTypeUnverifiedSNATPeer, fp)) + //}) + // + //t.Run("snat peer rejected without snat addr", func(t *testing.T) { + // fw := &Firewall{} + // assert.ErrorIs(t, fw.allowRemoteNetworkType(NetworkTypeUnverifiedSNATPeer, fp), ErrInvalidRemoteIP) + //}) + + t.Run("vpn always allowed", func(t *testing.T) { + fw := &Firewall{} + assert.NoError(t, fw.allowRemoteNetworkType(NetworkTypeVPN, firewall.Packet{})) + }) + + t.Run("unsafe always allowed", func(t *testing.T) { + fw := &Firewall{} + assert.NoError(t, fw.allowRemoteNetworkType(NetworkTypeUnsafe, firewall.Packet{})) + }) + + t.Run("invalid peer rejected", func(t *testing.T) { + fw := &Firewall{} + assert.ErrorIs(t, fw.allowRemoteNetworkType(NetworkTypeInvalidPeer, firewall.Packet{}), ErrInvalidRemoteIP) + }) + + t.Run("vpn peer rejected", func(t *testing.T) { + fw := &Firewall{} + assert.ErrorIs(t, fw.allowRemoteNetworkType(NetworkTypeVPNPeer, firewall.Packet{}), ErrPeerRejected) + }) +} + +func TestFirewall_FindUsableSNATPort(t *testing.T) { + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + snatAddr := netip.MustParseAddr("169.254.55.96") + + t.Run("finds first available port", func(t *testing.T) { + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) + + fp := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: snatAddr, + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cn := &conn{snat: &snatInfo{}} + err := fw.findUsableSNATPort(&fp, cn) + require.NoError(t, err) + // Port should have been assigned + assert.Equal(t, uint16(12345), fp.RemotePort, "should use original port if available") + }) + + t.Run("skips occupied port", func(t *testing.T) { + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) + + fp := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: snatAddr, + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + // Occupy the port + fw.Conntrack.Lock() + fw.Conntrack.Conns[fp] = &conn{} + fw.Conntrack.Unlock() + + cn := &conn{snat: &snatInfo{}} + err := fw.findUsableSNATPort(&fp, cn) + require.NoError(t, err) + assert.NotEqual(t, uint16(12345), fp.RemotePort, "should pick a different port") + }) + + t.Run("returns error on exhaustion", func(t *testing.T) { + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) + + // Fill all ports + baseFP := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: snatAddr, + LocalPort: 80, + Protocol: firewall.ProtoUDP, + } + fw.Conntrack.Lock() + for i := 0; i < 65535; i++ { + fp := baseFP + fp.RemotePort = uint16(i) + fw.Conntrack.Conns[fp] = &conn{} + } + fw.Conntrack.Unlock() + + // Try to find a port starting from 0x8000 + fp := baseFP + fp.RemotePort = 0x8000 + cn := &conn{snat: &snatInfo{}} + err := fw.findUsableSNATPort(&fp, cn) + assert.ErrorIs(t, err, ErrCannotSNAT) + }) +} + +func TestFirewall_ApplySnat(t *testing.T) { + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + snatAddr := netip.MustParseAddr("169.254.55.96") + peerV6Addr := netip.MustParseAddr("fd00::1") + dstIP := netip.MustParseAddr("192.168.1.1") + + t.Run("new flow from v6 host", func(t *testing.T) { + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw.snatAddr = snatAddr + + pkt := slices.Clone(canonicalUDPTest) + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cn := &conn{} + h := &HostInfo{vpnAddrs: []netip.Addr{peerV6Addr}} + + err := fw.applySnat(pkt, &fp, cn, h) + require.NoError(t, err) + + // Should have created snat info + require.True(t, cn.snat.Valid()) + assert.Equal(t, peerV6Addr, cn.snat.SrcVpnIp) + assert.Equal(t, netip.MustParseAddr("10.0.0.1"), cn.snat.Src.Addr()) + assert.Equal(t, uint16(12345), cn.snat.Src.Port()) + + // Packet source should be rewritten to snatAddr + gotSrcIP, _ := netip.AddrFromSlice(pkt[12:16]) + assert.Equal(t, snatAddr, gotSrcIP) + }) + + t.Run("existing flow with matching identity", func(t *testing.T) { + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw.snatAddr = snatAddr + + pkt := slices.Clone(canonicalUDPTest) + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cn := &conn{ + snat: &snatInfo{ + Src: netip.AddrPortFrom(netip.MustParseAddr("10.0.0.1"), 12345), + SrcVpnIp: peerV6Addr, + SnatPort: 55555, + }, + } + h := &HostInfo{vpnAddrs: []netip.Addr{peerV6Addr}} + + err := fw.applySnat(pkt, &fp, cn, h) + require.NoError(t, err) + + // Source should be rewritten + gotSrcIP, _ := netip.AddrFromSlice(pkt[12:16]) + assert.Equal(t, snatAddr, gotSrcIP) + }) + + t.Run("identity mismatch rejected", func(t *testing.T) { + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw.snatAddr = snatAddr + + pkt := slices.Clone(canonicalUDPTest) + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cn := &conn{ + snat: &snatInfo{ + Src: netip.AddrPortFrom(netip.MustParseAddr("10.0.0.1"), 12345), + SrcVpnIp: netip.MustParseAddr("fd00::99"), // Different VPN IP + SnatPort: 55555, + }, + } + // Attacker has a different VPN address + h := &HostInfo{vpnAddrs: []netip.Addr{peerV6Addr}} + + err := fw.applySnat(pkt, &fp, cn, h) + assert.ErrorIs(t, err, ErrSNATIdentityMismatch) + }) + + t.Run("no snat addr configured", func(t *testing.T) { + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) + + pkt := slices.Clone(canonicalUDPTest) + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cn := &conn{} + h := &HostInfo{vpnAddrs: []netip.Addr{peerV6Addr}} + + err := fw.applySnat(pkt, &fp, cn, h) + assert.ErrorIs(t, err, ErrCannotSNAT) + }) + + t.Run("v4 host rejected for new flow", func(t *testing.T) { + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw.snatAddr = snatAddr + + pkt := slices.Clone(canonicalUDPTest) + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cn := &conn{} + // This host has a v4 address - can't SNAT for it + h := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("10.0.0.50")}} + + err := fw.applySnat(pkt, &fp, cn, h) + assert.ErrorIs(t, err, ErrCannotSNAT) + }) +} + +func TestFirewall_UnSnat(t *testing.T) { + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + snatAddr := netip.MustParseAddr("169.254.55.96") + peerV6Addr := netip.MustParseAddr("fd00::1") + origSrcIP := netip.MustParseAddr("10.0.0.1") + + t.Run("successful unsnat", func(t *testing.T) { + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw.snatAddr = snatAddr + + // Create a conntrack entry for the snatted flow + snatFP := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: snatAddr, + LocalPort: 80, + RemotePort: 55555, + Protocol: firewall.ProtoUDP, + } + fw.Conntrack.Lock() + fw.Conntrack.Conns[snatFP] = &conn{ + snat: &snatInfo{ + Src: netip.AddrPortFrom(origSrcIP, 12345), + SrcVpnIp: peerV6Addr, + SnatPort: 55555, + }, + } + fw.Conntrack.Unlock() + + pkt := slices.Clone(canonicalUDPReply) + + fp := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: snatAddr, + LocalPort: 80, + RemotePort: 55555, + Protocol: firewall.ProtoUDP, + } + + result := fw.unSnat(pkt, &fp) + assert.True(t, result.IsValid()) + assert.Equal(t, peerV6Addr, result) + + // Destination should be rewritten to the original source + gotDstIP, _ := netip.AddrFromSlice(pkt[16:20]) + assert.Equal(t, origSrcIP, gotDstIP) + }) + + t.Run("no conntrack entry", func(t *testing.T) { + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw.snatAddr = snatAddr + + pkt := slices.Clone(canonicalUDPReply) + fp := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: snatAddr, + LocalPort: 80, + RemotePort: 55555, + Protocol: firewall.ProtoUDP, + } + + result := fw.unSnat(pkt, &fp) + assert.False(t, result.IsValid()) + }) +} + +func TestFirewall_Drop_SNATFullFlow(t *testing.T) { + // Integration test: a complete SNAT flow through Drop + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + snatAddr := netip.MustParseAddr("169.254.55.96") + myV6Prefix := netip.MustParsePrefix("fd00::1/128") + unsafeNet := netip.MustParsePrefix("192.168.0.0/16") + + myCert := &dummyCert{ + name: "me", + networks: []netip.Prefix{myV6Prefix}, + unsafeNetworks: []netip.Prefix{unsafeNet}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert) + fw.snatAddr = snatAddr + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "any", "", "")) + + // Set up the peer: an IPv6-only host sending IPv4 traffic + peerV6Addr := netip.MustParseAddr("fd00::2") + peerCert := &dummyCert{ + name: "peer", + networks: []netip.Prefix{netip.MustParsePrefix("fd00::2/128")}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(myV6Prefix) + + h := &HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &cert.CachedCertificate{ + Certificate: peerCert, + InvertedGroups: map[string]struct{}{"default-group": {}}, + }, + }, + vpnAddrs: []netip.Addr{peerV6Addr}, + } + h.buildNetworks(myVpnNetworksTable, peerCert) + + pkt := slices.Clone(canonicalUDPSnatMe) + + fp := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cp := cert.NewCAPool() + + // Drop should succeed and SNAT the packet + err := fw.Drop(fp, pkt, true, h, cp, nil) + require.NoError(t, err) + + // After Drop, the source should be rewritten to the snat addr + gotSrcIP, _ := netip.AddrFromSlice(pkt[12:16]) + assert.Equal(t, snatAddr, gotSrcIP) +} + +func TestHasOnlyV6Addresses(t *testing.T) { + t.Run("v6 only", func(t *testing.T) { + h := &HostInfo{vpnAddrs: []netip.Addr{ + netip.MustParseAddr("fd00::1"), + netip.MustParseAddr("fd00::2"), + }} + assert.True(t, h.HasOnlyV6Addresses()) + }) + + t.Run("v4 only", func(t *testing.T) { + h := &HostInfo{vpnAddrs: []netip.Addr{ + netip.MustParseAddr("10.0.0.1"), + }} + assert.False(t, h.HasOnlyV6Addresses()) + }) + + t.Run("mixed v4 and v6", func(t *testing.T) { + h := &HostInfo{vpnAddrs: []netip.Addr{ + netip.MustParseAddr("fd00::1"), + netip.MustParseAddr("10.0.0.1"), + }} + assert.False(t, h.HasOnlyV6Addresses()) + }) +} + +// --- Adversarial SNAT Tests --- + +func TestFirewall_ApplySnat_CrossHostHijack(t *testing.T) { + // Host A (fd00::1) establishes SNAT flow. Host B (fd00::2) sends a packet + // matching the same conntrack key but with a different identity. + // applySnat must reject with ErrSNATIdentityMismatch and leave the packet unmodified. + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + snatAddr := netip.MustParseAddr("169.254.55.96") + hostA := netip.MustParseAddr("fd00::1") + hostB := netip.MustParseAddr("fd00::2") + + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw.snatAddr = snatAddr + + // Simulate Host A having established a flow + cn := &conn{ + snat: &snatInfo{ + Src: netip.AddrPortFrom(netip.MustParseAddr("10.0.0.1"), 12345), + SrcVpnIp: hostA, + SnatPort: 55555, + }, + } + + // Host B tries to reuse the same conntrack entry + pkt := slices.Clone(canonicalUDPHijack) + + fp := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + hB := &HostInfo{vpnAddrs: []netip.Addr{hostB}} + + err := fw.applySnat(pkt, &fp, cn, hB) + require.ErrorIs(t, err, ErrSNATIdentityMismatch) + assert.Equal(t, canonicalUDPHijack, pkt, "packet bytes must be unmodified after identity mismatch") +} + +func TestFirewall_ApplySnat_MixedStackRejected(t *testing.T) { + // A host with both v4 and v6 VPN addresses should never get SNAT treatment. + // Test both orderings of vpnAddrs to verify behavior. + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + snatAddr := netip.MustParseAddr("169.254.55.96") + dstIP := netip.MustParseAddr("192.168.1.1") + + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + + t.Run("v6 first then v4", func(t *testing.T) { + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw.snatAddr = snatAddr + + pkt := slices.Clone(canonicalUDPTest) + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cn := &conn{} + // Mixed-stack: v6 first. applySnat checks vpnAddrs[0].Is6() which is true, + // so it would create a flow. But the caller (Drop) guards with HasOnlyV6Addresses(). + // This test documents that applySnat alone doesn't prevent mixed-stack SNAT. + h := &HostInfo{vpnAddrs: []netip.Addr{ + netip.MustParseAddr("fd00::1"), + netip.MustParseAddr("10.0.0.50"), + }} + + err := fw.applySnat(pkt, &fp, cn, h) + // applySnat only checks vpnAddrs[0].Is6(), so this succeeds. + // The real guard is in Drop() via HasOnlyV6Addresses(). + assert.NoError(t, err, "applySnat alone allows v6-first mixed-stack (guarded by Drop)") + }) + + t.Run("v4 first then v6", func(t *testing.T) { + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw.snatAddr = snatAddr + + pkt := slices.Clone(canonicalUDPTest) + + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cn := &conn{} + // Mixed-stack: v4 first. vpnAddrs[0].Is6() is false -> ErrCannotSNAT. + h := &HostInfo{vpnAddrs: []netip.Addr{ + netip.MustParseAddr("10.0.0.50"), + netip.MustParseAddr("fd00::1"), + }} + + err := fw.applySnat(pkt, &fp, cn, h) + require.ErrorIs(t, err, ErrCannotSNAT) + assert.Equal(t, canonicalUDPTest, pkt, "packet bytes must be unmodified on error") + }) +} + +func TestFirewall_ApplySnat_PacketUnmodifiedOnError(t *testing.T) { + // When applySnat returns an error, the packet must not be partially rewritten. + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + dstIP := netip.MustParseAddr("192.168.1.1") + + t.Run("no snatAddr configured", func(t *testing.T) { + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) + + pkt := slices.Clone(canonicalUDPTest) + + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cn := &conn{} + h := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("fd00::1")}} + + err := fw.applySnat(pkt, &fp, cn, h) + require.Error(t, err) + assert.Equal(t, canonicalUDPTest, pkt, "packet must be byte-for-byte identical after error") + }) + + t.Run("identity mismatch", func(t *testing.T) { + snatAddr := netip.MustParseAddr("169.254.55.96") + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw.snatAddr = snatAddr + + pkt := slices.Clone(canonicalUDPTest) + + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cn := &conn{ + snat: &snatInfo{ + Src: netip.AddrPortFrom(netip.MustParseAddr("10.0.0.1"), 12345), + SrcVpnIp: netip.MustParseAddr("fd00::99"), + SnatPort: 55555, + }, + } + h := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("fd00::1")}} + + err := fw.applySnat(pkt, &fp, cn, h) + require.ErrorIs(t, err, ErrSNATIdentityMismatch) + assert.Equal(t, canonicalUDPTest, pkt, "packet must be byte-for-byte identical after identity mismatch") + }) + + t.Run("v4 host rejected", func(t *testing.T) { + snatAddr := netip.MustParseAddr("169.254.55.96") + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw.snatAddr = snatAddr + + pkt := slices.Clone(canonicalUDPTest) + + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cn := &conn{} + h := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("10.0.0.50")}} + + err := fw.applySnat(pkt, &fp, cn, h) + require.ErrorIs(t, err, ErrCannotSNAT) + assert.Equal(t, canonicalUDPTest, pkt, "packet must be byte-for-byte identical after v4 host rejection") + }) +} + +func TestFirewall_UnSnat_NonSNATConntrack(t *testing.T) { + // A conntrack entry exists but has snat=nil. unSnat should return an invalid addr + // and not rewrite the packet. + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + snatAddr := netip.MustParseAddr("169.254.55.96") + + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw.snatAddr = snatAddr + + // Create a conntrack entry with snat=nil (a normal non-SNAT connection) + snatFP := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: snatAddr, + LocalPort: 80, + RemotePort: 55555, + Protocol: firewall.ProtoUDP, + } + fw.Conntrack.Lock() + fw.Conntrack.Conns[snatFP] = &conn{ + snat: nil, // deliberately nil + } + fw.Conntrack.Unlock() + + pkt := slices.Clone(canonicalUDPReply) + + fp := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: snatAddr, + LocalPort: 80, + RemotePort: 55555, + Protocol: firewall.ProtoUDP, + } + + result := fw.unSnat(pkt, &fp) + assert.False(t, result.IsValid(), "unSnat should return invalid addr for non-SNAT conntrack entry") + assert.Equal(t, canonicalUDPReply, pkt, "packet must not be rewritten when conntrack has no snat info") +} + +func TestFirewall_Drop_FirewallBlocksSNAT(t *testing.T) { + // Firewall rules only allow port 80. An SNAT-eligible packet to port 443 + // must be rejected with ErrNoMatchingRule BEFORE any SNAT rewriting occurs. + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + snatAddr := netip.MustParseAddr("169.254.55.96") + + myCert := &dummyCert{ + name: "me", + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert) + fw.snatAddr = snatAddr + // Only allow port 80 inbound + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 80, 80, []string{"any"}, "", "", "any", "", "")) + + peerV6Addr := netip.MustParseAddr("fd00::2") + peerCert := &dummyCert{ + name: "peer", + networks: []netip.Prefix{netip.MustParsePrefix("fd00::2/128")}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::1/128")) + + h := &HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &cert.CachedCertificate{ + Certificate: peerCert, + InvertedGroups: map[string]struct{}{"default-group": {}}, + }, + }, + vpnAddrs: []netip.Addr{peerV6Addr}, + } + h.buildNetworks(myVpnNetworksTable, peerCert) + + // Send to port 443 (not allowed) + pkt := slices.Clone(canonicalUDPBlocked) + + fp := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 443, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cp := cert.NewCAPool() + + err := fw.Drop(fp, pkt, true, h, cp, nil) + require.ErrorIs(t, err, ErrNoMatchingRule, "firewall should block SNAT-eligible traffic that doesn't match rules") + assert.Equal(t, canonicalUDPBlocked, pkt, "packet must not be rewritten when firewall blocks it") +} + +func TestFirewall_Drop_SNATLocalAddrNotRoutable(t *testing.T) { + // An SNAT peer sends IPv4 traffic to an address NOT in routableNetworks. + // willingToHandleLocalAddr should reject with ErrInvalidLocalIP. + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + snatAddr := netip.MustParseAddr("169.254.55.96") + + myCert := &dummyCert{ + name: "me", + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert) + fw.snatAddr = snatAddr + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "any", "", "")) + + peerV6Addr := netip.MustParseAddr("fd00::2") + peerCert := &dummyCert{ + name: "peer", + networks: []netip.Prefix{netip.MustParsePrefix("fd00::2/128")}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::1/128")) + + h := &HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &cert.CachedCertificate{ + Certificate: peerCert, + InvertedGroups: map[string]struct{}{"default-group": {}}, + }, + }, + vpnAddrs: []netip.Addr{peerV6Addr}, + } + h.buildNetworks(myVpnNetworksTable, peerCert) + + // Dest 172.16.0.1 is NOT in our routableNetworks (which only has fd00::1/128 and 192.168.0.0/16) + pkt := slices.Clone(canonicalUDPWrongDest) + + fp := firewall.Packet{ + LocalAddr: netip.MustParseAddr("172.16.0.1"), + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cp := cert.NewCAPool() + + err := fw.Drop(fp, pkt, true, h, cp, nil) + assert.ErrorIs(t, err, ErrInvalidLocalIP, "traffic to non-routable local address should be rejected") +} + +func TestFirewall_Drop_NoSnatAddrRejectsV6Peer(t *testing.T) { + // Firewall has no snatAddr configured. An IPv6-only peer sends IPv4 traffic. + // allowRemoteNetworkType(UncheckedSNATPeer) should reject with ErrInvalidRemoteIP. + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + myCert := &dummyCert{ + name: "me", + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "any", "", "")) + + peerV6Addr := netip.MustParseAddr("fd00::2") + peerCert := &dummyCert{ + name: "peer", + networks: []netip.Prefix{netip.MustParsePrefix("fd00::2/128")}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::1/128")) + + h := &HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &cert.CachedCertificate{ + Certificate: peerCert, + InvertedGroups: map[string]struct{}{"default-group": {}}, + }, + }, + vpnAddrs: []netip.Addr{peerV6Addr}, + } + h.buildNetworks(myVpnNetworksTable, peerCert) + + pkt := slices.Clone(canonicalUDPNoSnat) + + fp := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cp := cert.NewCAPool() + + err := fw.Drop(fp, pkt, true, h, cp, nil) + assert.ErrorIs(t, err, ErrInvalidRemoteIP, "v6 peer with no snatAddr should be rejected") +} + +func TestFirewall_Drop_IPv4HostNotSNATted(t *testing.T) { + // An IPv4 VPN host sends IPv4 traffic. Even though the router has snatAddr + // configured and the traffic is IPv4, the firewall must NOT treat this as + // UncheckedSNATPeer. The packet must not be SNAT-rewritten. + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + snatAddr := netip.MustParseAddr("169.254.55.96") + + t.Run("v6-only router rejects v4 peer as VPNPeer", func(t *testing.T) { + // When the router is v6-only, the v4 peer's address is outside our VPN + // networks -> classified as NetworkTypeVPNPeer -> rejected (not SNATted). + myCert := &dummyCert{ + name: "me", + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert) + fw.snatAddr = snatAddr + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "any", "", "")) + + peerV4Addr := netip.MustParseAddr("10.128.0.2") + peerCert := &dummyCert{ + name: "v4peer", + networks: []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::1/128")) + + h := &HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &cert.CachedCertificate{ + Certificate: peerCert, + InvertedGroups: map[string]struct{}{"default-group": {}}, + }, + }, + vpnAddrs: []netip.Addr{peerV4Addr}, + } + h.buildNetworks(myVpnNetworksTable, peerCert) + + pkt := slices.Clone(canonicalUDPV4Traffic) + + fp := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: netip.MustParseAddr("10.128.0.2"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cp := cert.NewCAPool() + + err := fw.Drop(fp, pkt, true, h, cp, nil) + require.ErrorIs(t, err, ErrPeerRejected, "IPv4 peer should be rejected as VPNPeer, not treated as SNAT") + assert.Equal(t, canonicalUDPV4Traffic, pkt, "packet must not be rewritten when peer is rejected") + }) + + t.Run("identifyRemoteNetworkType classifies v4 peer correctly", func(t *testing.T) { + // Directly verify that identifyRemoteNetworkType returns the right type for + // an IPv4 peer (not UncheckedSNATPeer). + fw := &Firewall{snatAddr: snatAddr} + + // Simple case: v4 host, no networks table + h := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("10.128.0.2")}} + fp := firewall.Packet{ + RemoteAddr: netip.MustParseAddr("10.128.0.2"), + LocalAddr: netip.MustParseAddr("192.168.1.1"), + } + nwType := fw.identifyRemoteNetworkType(h, fp) + assert.Equal(t, NetworkTypeVPN, nwType, "v4 peer using its own VPN addr should be NetworkTypeVPN") + assert.NotEqual(t, NetworkTypeUnverifiedSNATPeer, nwType, "must NOT be classified as SNAT peer") + }) + + t.Run("identifyRemoteNetworkType v4 peer with mismatched source", func(t *testing.T) { + // v4 host sends with a source IP that doesn't match its VPN addr + fw := &Firewall{snatAddr: snatAddr} + + h := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("10.128.0.2")}} + fp := firewall.Packet{ + RemoteAddr: netip.MustParseAddr("10.0.0.99"), // Not the peer's VPN addr + LocalAddr: netip.MustParseAddr("192.168.1.1"), + } + nwType := fw.identifyRemoteNetworkType(h, fp) + assert.Equal(t, NetworkTypeInvalidPeer, nwType, "v4 peer with mismatched source should be InvalidPeer") + assert.NotEqual(t, NetworkTypeUnverifiedSNATPeer, nwType, "must NOT be classified as SNAT peer") + }) +} diff --git a/test/tun.go b/test/tun.go index fb32782f..e967568b 100644 --- a/test/tun.go +++ b/test/tun.go @@ -10,6 +10,18 @@ import ( type NoopTun struct{} +func (NoopTun) UnsafeNetworks() []netip.Prefix { + return nil +} + +func (NoopTun) SNATAddress() netip.Prefix { + return netip.Prefix{} +} + +func (NoopTun) UnsafeIPv4OriginAddress() netip.Prefix { + return netip.Prefix{} +} + func (NoopTun) RoutesFor(addr netip.Addr) routing.Gateways { return routing.Gateways{} }