From 25610225bbff9ef46071486e0543f8d8fa95b47d Mon Sep 17 00:00:00 2001 From: JackDoan Date: Wed, 18 Feb 2026 15:07:57 -0600 Subject: [PATCH] crappy AI tests --- e2e/snat_test.go | 400 ++++++++++++ overlay/tun_snat_test.go | 171 +++++ snat_test.go | 1309 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 1880 insertions(+) create mode 100644 e2e/snat_test.go create mode 100644 overlay/tun_snat_test.go create mode 100644 snat_test.go diff --git a/e2e/snat_test.go b/e2e/snat_test.go new file mode 100644 index 00000000..a0e0af96 --- /dev/null +++ b/e2e/snat_test.go @@ -0,0 +1,400 @@ +//go:build e2e_testing +// +build e2e_testing + +package e2e + +import ( + "encoding/binary" + "net/netip" + "testing" + "time" + + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/cert_test" + "github.com/slackhq/nebula/e2e/router" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// parseIPv4UDPPacket extracts source/dest IPs, ports, and payload from an IPv4 UDP packet. +func parseIPv4UDPPacket(t testing.TB, pkt []byte) (srcIP, dstIP netip.Addr, srcPort, dstPort uint16, payload []byte) { + t.Helper() + require.True(t, len(pkt) >= 28, "packet too short for IPv4+UDP header") + require.Equal(t, byte(0x45), pkt[0]&0xF0|pkt[0]&0x0F, "not a simple IPv4 packet (IHL!=5)") + + srcIP, _ = netip.AddrFromSlice(pkt[12:16]) + dstIP, _ = netip.AddrFromSlice(pkt[16:20]) + + ihl := int(pkt[0]&0x0F) * 4 + require.True(t, len(pkt) >= ihl+8, "packet too short for UDP header") + srcPort = binary.BigEndian.Uint16(pkt[ihl : ihl+2]) + dstPort = binary.BigEndian.Uint16(pkt[ihl+2 : ihl+4]) + udpLen := binary.BigEndian.Uint16(pkt[ihl+4 : ihl+6]) + payload = pkt[ihl+8 : ihl+int(udpLen)] + return +} + +func TestSNAT_IPv6OnlyPeer_IPv4UnsafeTraffic(t *testing.T) { + // Scenario: Two IPv6-only VPN nodes. The "router" node has unsafe networks + // (192.168.0.0/16) in its cert and a configured SNAT address. The "sender" + // node has an unsafe route for 192.168.0.0/16 via the router. + // + // When sender injects an IPv4 packet destined for the unsafe network, it + // gets tunneled to the router. The router's firewall detects this is IPv4 + // from an IPv6-only peer and applies SNAT, rewriting the source IP to the + // SNAT address before delivering it to TUN. + // + // When a reply comes back from TUN addressed to the SNAT address, the + // router un-SNATs it (restoring the original destination) and tunnels it + // back to the sender. + + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + + unsafePrefix := "192.168.0.0/16" + snatAddr := netip.MustParseAddr("169.254.42.42") + + // Router: IPv6-only with unsafe networks and a manual SNAT address. + // Override inbound firewall with local_cidr: "any" so both IPv4 (unsafe) + // and IPv6 (VPN) traffic is accepted. + routerControl, routerVpnIpNet, routerUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks( + cert.Version2, ca, caKey, "router", "ff::1/64", + netip.MustParseAddrPort("[beef::1]:4242"), + unsafePrefix, + m{ + "firewall": m{ + "inbound": []m{{ + "proto": "any", + "port": "any", + "host": "any", + "local_cidr": "any", + }}, + }, + "tun": m{ + "snat_address_for_4over6": snatAddr.String(), + }, + }, + ) + + // Sender: IPv6-only with an unsafe route via the router + senderControl, _, _, _ := newSimpleServerWithUdp( + cert.Version2, ca, caKey, "sender", "ff::2/64", + netip.MustParseAddrPort("[beef::2]:4242"), + m{ + "tun": m{ + "unsafe_routes": []m{ + {"route": unsafePrefix, "via": routerVpnIpNet[0].Addr().String()}, + }, + }, + }, + ) + + // Tell sender where the router lives + senderControl.InjectLightHouseAddr(routerVpnIpNet[0].Addr(), routerUdpAddr) + + // Build the router and start both nodes + r := router.NewR(t, routerControl, senderControl) + defer r.RenderFlow() + + routerControl.Start() + senderControl.Start() + + // --- Outbound: sender -> IPv4 unsafe dest (via router with SNAT) --- + + origSrcIP := netip.MustParseAddr("10.0.0.1") + unsafeDest := netip.MustParseAddr("192.168.1.1") + var origSrcPort uint16 = 12345 + var dstPort uint16 = 80 + + t.Log("Sender injects an IPv4 packet to the unsafe network") + senderControl.InjectTunUDPPacket(unsafeDest, dstPort, origSrcIP, origSrcPort, []byte("snat me")) + + t.Log("Route packets (handshake + data) until the router gets the packet on TUN") + snatPkt := r.RouteForAllUntilTxTun(routerControl) + + t.Log("Verify the packet was SNATted") + gotSrcIP, gotDstIP, gotSrcPort, gotDstPort, gotPayload := parseIPv4UDPPacket(t, snatPkt) + assert.Equal(t, snatAddr, gotSrcIP, "source IP should be rewritten to the SNAT address") + assert.Equal(t, unsafeDest, gotDstIP, "destination IP should be unchanged") + assert.Equal(t, dstPort, gotDstPort, "destination port should be unchanged") + assert.Equal(t, []byte("snat me"), gotPayload, "payload should be unchanged") + + // Capture the SNAT port (may differ from original if port was remapped) + snatPort := gotSrcPort + t.Logf("SNAT port: %d (original: %d)", snatPort, origSrcPort) + + // --- Return: reply from unsafe dest -> un-SNATted back to sender --- + + t.Log("Router injects a reply packet from the unsafe dest to the SNAT address") + routerControl.InjectTunUDPPacket(snatAddr, snatPort, unsafeDest, dstPort, []byte("reply from unsafe")) + + t.Log("Route until sender gets the reply on TUN") + replyPkt := r.RouteForAllUntilTxTun(senderControl) + + t.Log("Verify the reply was un-SNATted") + replySrcIP, replyDstIP, replySrcPort, replyDstPort, replyPayload := parseIPv4UDPPacket(t, replyPkt) + assert.Equal(t, unsafeDest, replySrcIP, "reply source should be the unsafe dest") + assert.Equal(t, origSrcIP, replyDstIP, "reply dest should be the original source IP (un-SNATted)") + assert.Equal(t, dstPort, replySrcPort, "reply source port should be the unsafe dest port") + assert.Equal(t, origSrcPort, replyDstPort, "reply dest port should be the original source port (un-SNATted)") + assert.Equal(t, []byte("reply from unsafe"), replyPayload, "payload should be unchanged") + + r.RenderHostmaps("Final hostmaps", routerControl, senderControl) + + // Also verify normal IPv6 VPN traffic still works between the nodes + t.Log("Verify normal IPv6 VPN tunnel still works") + assertTunnel(t, routerVpnIpNet[0].Addr(), senderControl.GetVpnAddrs()[0], routerControl, senderControl, r) + + routerControl.Stop() + senderControl.Stop() +} + +func TestSNAT_MultipleFlows(t *testing.T) { + // Test that multiple distinct IPv4 flows from the same IPv6-only peer + // are tracked independently through SNAT. + + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + + unsafePrefix := "192.168.0.0/16" + snatAddr := netip.MustParseAddr("169.254.42.42") + + routerControl, routerVpnIpNet, routerUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks( + cert.Version2, ca, caKey, "router", "ff::1/64", + netip.MustParseAddrPort("[beef::1]:4242"), + unsafePrefix, + m{ + "firewall": m{ + "inbound": []m{{ + "proto": "any", + "port": "any", + "host": "any", + "local_cidr": "any", + }}, + }, + "tun": m{ + "snat_address_for_4over6": snatAddr.String(), + }, + }, + ) + + senderControl, _, _, _ := newSimpleServerWithUdp( + cert.Version2, ca, caKey, "sender", "ff::2/64", + netip.MustParseAddrPort("[beef::2]:4242"), + m{ + "tun": m{ + "unsafe_routes": []m{ + {"route": unsafePrefix, "via": routerVpnIpNet[0].Addr().String()}, + }, + }, + }, + ) + + senderControl.InjectLightHouseAddr(routerVpnIpNet[0].Addr(), routerUdpAddr) + + r := router.NewR(t, routerControl, senderControl) + defer r.RenderFlow() + r.CancelFlowLogs() + + routerControl.Start() + senderControl.Start() + + unsafeDest := netip.MustParseAddr("192.168.1.1") + + // Send first flow + senderControl.InjectTunUDPPacket(unsafeDest, 80, netip.MustParseAddr("10.0.0.1"), 1111, []byte("flow1")) + pkt1 := r.RouteForAllUntilTxTun(routerControl) + srcIP1, _, srcPort1, _, payload1 := parseIPv4UDPPacket(t, pkt1) + assert.Equal(t, snatAddr, srcIP1) + assert.Equal(t, []byte("flow1"), payload1) + + // Send second flow (different source port) + senderControl.InjectTunUDPPacket(unsafeDest, 80, netip.MustParseAddr("10.0.0.1"), 2222, []byte("flow2")) + pkt2 := r.RouteForAllUntilTxTun(routerControl) + srcIP2, _, srcPort2, _, payload2 := parseIPv4UDPPacket(t, pkt2) + assert.Equal(t, snatAddr, srcIP2) + assert.Equal(t, []byte("flow2"), payload2) + + // The two flows should have different SNAT ports (since they're different conntracks) + t.Logf("Flow 1 SNAT port: %d, Flow 2 SNAT port: %d", srcPort1, srcPort2) + + // Reply to flow 2 first (out of order) + routerControl.InjectTunUDPPacket(snatAddr, srcPort2, unsafeDest, 80, []byte("reply2")) + reply2 := r.RouteForAllUntilTxTun(senderControl) + _, replyDst2, _, replyDstPort2, replyPayload2 := parseIPv4UDPPacket(t, reply2) + assert.Equal(t, netip.MustParseAddr("10.0.0.1"), replyDst2) + assert.Equal(t, uint16(2222), replyDstPort2, "reply to flow 2 should restore original port 2222") + assert.Equal(t, []byte("reply2"), replyPayload2) + + // Reply to flow 1 + routerControl.InjectTunUDPPacket(snatAddr, srcPort1, unsafeDest, 80, []byte("reply1")) + reply1 := r.RouteForAllUntilTxTun(senderControl) + _, replyDst1, _, replyDstPort1, replyPayload1 := parseIPv4UDPPacket(t, reply1) + assert.Equal(t, netip.MustParseAddr("10.0.0.1"), replyDst1) + assert.Equal(t, uint16(1111), replyDstPort1, "reply to flow 1 should restore original port 1111") + assert.Equal(t, []byte("reply1"), replyPayload1) + + routerControl.Stop() + senderControl.Stop() +} + +// --- Adversarial SNAT E2E Tests --- + +func TestSNAT_UnsolicitedReplyDropped(t *testing.T) { + // Without any outbound SNAT traffic, inject a packet from the router's TUN + // addressed to the SNAT address. The sender must never receive it because + // there's no conntrack entry to un-SNAT through. + + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + + unsafePrefix := "192.168.0.0/16" + snatAddr := netip.MustParseAddr("169.254.42.42") + + routerControl, routerVpnIpNet, routerUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks( + cert.Version2, ca, caKey, "router", "ff::1/64", + netip.MustParseAddrPort("[beef::1]:4242"), + unsafePrefix, + m{ + "firewall": m{ + "inbound": []m{{ + "proto": "any", + "port": "any", + "host": "any", + "local_cidr": "any", + }}, + }, + "tun": m{ + "snat_address_for_4over6": snatAddr.String(), + }, + }, + ) + + senderControl, _, _, _ := newSimpleServerWithUdp( + cert.Version2, ca, caKey, "sender", "ff::2/64", + netip.MustParseAddrPort("[beef::2]:4242"), + m{ + "tun": m{ + "unsafe_routes": []m{ + {"route": unsafePrefix, "via": routerVpnIpNet[0].Addr().String()}, + }, + }, + }, + ) + + senderControl.InjectLightHouseAddr(routerVpnIpNet[0].Addr(), routerUdpAddr) + + r := router.NewR(t, routerControl, senderControl) + defer r.RenderFlow() + r.CancelFlowLogs() + + routerControl.Start() + senderControl.Start() + + // First establish the tunnel with normal IPv6 traffic so handshake completes + assertTunnel(t, routerVpnIpNet[0].Addr(), senderControl.GetVpnAddrs()[0], routerControl, senderControl, r) + + // Inject the unsolicited reply from router's TUN to the SNAT address. + // There is NO prior outbound SNAT flow, so no conntrack entry exists. + // The router should silently drop this because unSnat finds no matching conntrack. + routerControl.InjectTunUDPPacket(snatAddr, 55555, netip.MustParseAddr("192.168.1.1"), 80, []byte("unsolicited")) + + // Send a canary IPv6 VPN packet after the bad one. Since the router processes + // TUN packets sequentially, the canary arriving proves the bad packet was processed first. + senderVpnAddr := senderControl.GetVpnAddrs()[0] + routerControl.InjectTunUDPPacket(senderVpnAddr, 90, routerVpnIpNet[0].Addr(), 80, []byte("canary")) + canaryPkt := r.RouteForAllUntilTxTun(senderControl) + assertUdpPacket(t, []byte("canary"), canaryPkt, routerVpnIpNet[0].Addr(), senderVpnAddr, 80, 90) + + // The unsolicited packet should have been dropped — nothing else on sender's TUN + got := senderControl.GetFromTun(false) + assert.Nil(t, got, "sender should not receive unsolicited packet to SNAT address with no conntrack entry") + + routerControl.Stop() + senderControl.Stop() +} + +func TestSNAT_NonUnsafeDestDropped(t *testing.T) { + // An IPv6-only sender sends IPv4 traffic to a destination outside the router's + // unsafe networks (172.16.0.1 when unsafe is 192.168.0.0/16). The router should + // reject this because the local address is not routable. This verifies that + // willingToHandleLocalAddr enforces boundaries on what SNAT traffic can reach. + + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + + unsafePrefix := "192.168.0.0/16" + snatAddr := netip.MustParseAddr("169.254.42.42") + + routerControl, routerVpnIpNet, routerUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks( + cert.Version2, ca, caKey, "router", "ff::1/64", + netip.MustParseAddrPort("[beef::1]:4242"), + unsafePrefix, + m{ + "firewall": m{ + "inbound": []m{{ + "proto": "any", + "port": "any", + "host": "any", + "local_cidr": "any", + }}, + }, + "tun": m{ + "snat_address_for_4over6": snatAddr.String(), + }, + }, + ) + + // Sender has unsafe routes for BOTH 192.168.0.0/16 AND 172.16.0.0/12 via router. + // This means the sender will route 172.16.0.1 through the tunnel to the router. + // But the router should reject it because 172.16.0.0/12 is NOT in its unsafe networks. + senderControl, _, _, _ := newSimpleServerWithUdp( + cert.Version2, ca, caKey, "sender", "ff::2/64", + netip.MustParseAddrPort("[beef::2]:4242"), + m{ + "tun": m{ + "unsafe_routes": []m{ + {"route": unsafePrefix, "via": routerVpnIpNet[0].Addr().String()}, + {"route": "172.16.0.0/12", "via": routerVpnIpNet[0].Addr().String()}, + }, + }, + }, + ) + + senderControl.InjectLightHouseAddr(routerVpnIpNet[0].Addr(), routerUdpAddr) + + r := router.NewR(t, routerControl, senderControl) + defer r.RenderFlow() + r.CancelFlowLogs() + + routerControl.Start() + senderControl.Start() + + // Establish the tunnel first + assertTunnel(t, routerVpnIpNet[0].Addr(), senderControl.GetVpnAddrs()[0], routerControl, senderControl, r) + + // Send to 172.16.0.1 (NOT in router's unsafe networks 192.168.0.0/16). + // The router should reject this at willingToHandleLocalAddr. + senderControl.InjectTunUDPPacket( + netip.MustParseAddr("172.16.0.1"), 80, + netip.MustParseAddr("10.0.0.1"), 12345, + []byte("wrong dest"), + ) + + // Send a canary to a valid unsafe destination to prove the bad packet was processed + senderControl.InjectTunUDPPacket( + netip.MustParseAddr("192.168.1.1"), 80, + netip.MustParseAddr("10.0.0.1"), 33333, + []byte("canary"), + ) + + // Route until the canary arrives — the 172.16.0.1 packet should have been + // processed and dropped before the canary gets through + canaryPkt := r.RouteForAllUntilTxTun(routerControl) + _, canaryDst, _, _, canaryPayload := parseIPv4UDPPacket(t, canaryPkt) + assert.Equal(t, netip.MustParseAddr("192.168.1.1"), canaryDst, "canary should arrive at the valid unsafe dest") + assert.Equal(t, []byte("canary"), canaryPayload) + + // No more packets — the 172.16.0.1 packet was dropped + got := routerControl.GetFromTun(false) + assert.Nil(t, got, "packet to non-unsafe destination 172.16.0.1 should be dropped by the router") + + routerControl.Stop() + senderControl.Stop() +} diff --git a/overlay/tun_snat_test.go b/overlay/tun_snat_test.go new file mode 100644 index 00000000..0040edb4 --- /dev/null +++ b/overlay/tun_snat_test.go @@ -0,0 +1,171 @@ +package overlay + +import ( + "io" + "net/netip" + "testing" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockDevice is a minimal Device implementation for testing prepareSnatAddr. +type mockDevice struct { + networks []netip.Prefix + unsafeNetworks []netip.Prefix + snatAddr netip.Prefix +} + +func (d *mockDevice) Read([]byte) (int, error) { return 0, nil } +func (d *mockDevice) Write([]byte) (int, error) { return 0, nil } +func (d *mockDevice) Close() error { return nil } +func (d *mockDevice) Activate() error { return nil } +func (d *mockDevice) Networks() []netip.Prefix { return d.networks } +func (d *mockDevice) UnsafeNetworks() []netip.Prefix { return d.unsafeNetworks } +func (d *mockDevice) SNATAddress() netip.Prefix { return d.snatAddr } +func (d *mockDevice) Name() string { return "mock" } +func (d *mockDevice) RoutesFor(netip.Addr) routing.Gateways { return routing.Gateways{} } +func (d *mockDevice) SupportsMultiqueue() bool { return false } +func (d *mockDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, nil } + +func TestPrepareSnatAddr_V4Primary_NoSnat(t *testing.T) { + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + c := config.NewC(l) + + // If the device has an IPv4 primary address, no SNAT needed + d := &mockDevice{ + networks: []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}, + } + result := prepareSnatAddr(d, l, c, nil) + assert.Equal(t, netip.Prefix{}, result, "should not assign SNAT addr when device has IPv4 primary") +} + +func TestPrepareSnatAddr_V6Primary_NoUnsafeOrRoutes(t *testing.T) { + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + c := config.NewC(l) + + // IPv6 primary but no unsafe networks or IPv4 routes + d := &mockDevice{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + } + result := prepareSnatAddr(d, l, c, nil) + assert.Equal(t, netip.Prefix{}, result, "should not assign SNAT addr without IPv4 unsafe networks or routes") +} + +func TestPrepareSnatAddr_V6Primary_WithV4Unsafe(t *testing.T) { + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + c := config.NewC(l) + + // IPv6 primary with IPv4 unsafe network -> should get SNAT addr + d := &mockDevice{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + result := prepareSnatAddr(d, l, c, nil) + require.True(t, result.IsValid(), "should assign SNAT addr") + assert.True(t, result.Addr().Is4(), "SNAT addr should be IPv4") + assert.True(t, result.Addr().IsLinkLocalUnicast(), "SNAT addr should be link-local") + assert.Equal(t, 32, result.Bits(), "SNAT addr should be /32") +} + +func TestPrepareSnatAddr_V6Primary_WithV4Route(t *testing.T) { + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + c := config.NewC(l) + + // IPv6 primary with IPv4 route -> should get SNAT addr + d := &mockDevice{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + } + routes := []Route{ + {Cidr: netip.MustParsePrefix("10.0.0.0/8")}, + } + result := prepareSnatAddr(d, l, c, routes) + require.True(t, result.IsValid(), "should assign SNAT addr when IPv4 route exists") + assert.True(t, result.Addr().Is4()) + assert.True(t, result.Addr().IsLinkLocalUnicast()) +} + +func TestPrepareSnatAddr_V6Primary_V6UnsafeOnly(t *testing.T) { + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + c := config.NewC(l) + + // IPv6 primary with only IPv6 unsafe network -> no SNAT needed + d := &mockDevice{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("fd01::/64")}, + } + result := prepareSnatAddr(d, l, c, nil) + assert.Equal(t, netip.Prefix{}, result, "should not assign SNAT addr for IPv6-only unsafe networks") +} + +func TestPrepareSnatAddr_ManualAddress(t *testing.T) { + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + c := config.NewC(l) + c.Settings["tun"] = map[string]any{ + "snat_address_for_4over6": "169.254.42.42", + } + + d := &mockDevice{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + result := prepareSnatAddr(d, l, c, nil) + require.True(t, result.IsValid()) + assert.Equal(t, netip.MustParseAddr("169.254.42.42"), result.Addr()) + assert.Equal(t, 32, result.Bits()) +} + +func TestPrepareSnatAddr_InvalidManualAddress_Fallback(t *testing.T) { + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + c := config.NewC(l) + c.Settings["tun"] = map[string]any{ + "snat_address_for_4over6": "not-an-ip", + } + + d := &mockDevice{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + result := prepareSnatAddr(d, l, c, nil) + // Should fall back to auto-assignment + require.True(t, result.IsValid(), "should fall back to auto-assigned address") + assert.True(t, result.Addr().Is4()) + assert.True(t, result.Addr().IsLinkLocalUnicast()) +} + +func TestPrepareSnatAddr_AutoGenerated_Range(t *testing.T) { + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + c := config.NewC(l) + + d := &mockDevice{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + + // Generate several addresses and verify they're all in the expected range + for i := 0; i < 100; i++ { + result := prepareSnatAddr(d, l, c, nil) + require.True(t, result.IsValid()) + addr := result.Addr() + octets := addr.As4() + assert.Equal(t, byte(169), octets[0], "first octet should be 169") + assert.Equal(t, byte(254), octets[1], "second octet should be 254") + // Should not have .0 in the last octet + assert.NotEqual(t, byte(0), octets[3], "last octet should not be 0") + // Should not be 169.254.255.255 (broadcast) + if octets[2] == 255 { + assert.NotEqual(t, byte(255), octets[3], "should not be broadcast address") + } + } +} diff --git a/snat_test.go b/snat_test.go new file mode 100644 index 00000000..b6e2a116 --- /dev/null +++ b/snat_test.go @@ -0,0 +1,1309 @@ +package nebula + +import ( + "encoding/binary" + "net/netip" + "slices" + "testing" + "time" + + "github.com/gaissmai/bart" + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/firewall" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Canonical test packets with all checksums computed from scratch by +// /tmp/gen_canonical.go. Tests feed these into the production rewrite +// functions and compare byte-for-byte against expected outputs. + +// canonicalUDP: src=10.0.0.1:12345 dst=192.168.1.1:80 proto=UDP payload="hello world" +var canonicalUDP = []byte{ + 0x45, 0x00, 0x00, 0x27, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x5c, 0xe8, 0x0a, 0x00, 0x00, 0x01, + 0xc0, 0xa8, 0x01, 0x01, 0x30, 0x39, 0x00, 0x50, 0x00, 0x13, 0x71, 0xc6, 0x68, 0x65, 0x6c, 0x6c, + 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64, +} + +// canonicalTCP: src=10.0.0.1:12345 dst=192.168.1.1:80 proto=TCP payload="GET / HTTP/1.1" +var canonicalTCP = []byte{ + 0x45, 0x00, 0x00, 0x36, 0x12, 0x34, 0x40, 0x00, 0x40, 0x06, 0x5c, 0xe4, 0x0a, 0x00, 0x00, 0x01, + 0xc0, 0xa8, 0x01, 0x01, 0x30, 0x39, 0x00, 0x50, 0x12, 0x34, 0x56, 0x78, 0x00, 0x00, 0x00, 0x00, + 0x50, 0x02, 0xff, 0xff, 0x86, 0x68, 0x00, 0x00, 0x47, 0x45, 0x54, 0x20, 0x2f, 0x20, 0x48, 0x54, + 0x54, 0x50, 0x2f, 0x31, 0x2e, 0x31, +} + +// canonicalICMP: src=10.0.0.1 dst=192.168.1.1 proto=ICMP echo, id=0x1234 seq=1 +var canonicalICMP = []byte{ + 0x45, 0x00, 0x00, 0x1c, 0x12, 0x34, 0x40, 0x00, 0x40, 0x01, 0x5d, 0x03, 0x0a, 0x00, 0x00, 0x01, + 0xc0, 0xa8, 0x01, 0x01, 0x08, 0x00, 0xe5, 0xca, 0x12, 0x34, 0x00, 0x01, +} + +// canonicalUDPReply: src=192.168.1.1:80 dst=169.254.55.96:55555 proto=UDP payload="reply" +var canonicalUDPReply = []byte{ + 0x45, 0x00, 0x00, 0x21, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x85, 0x90, 0xc0, 0xa8, 0x01, 0x01, + 0xa9, 0xfe, 0x37, 0x60, 0x00, 0x50, 0xd9, 0x03, 0x00, 0x0d, 0x27, 0xa6, 0x72, 0x65, 0x70, 0x6c, + 0x79, +} + +// canonicalUDPTest: src=10.0.0.1:12345 dst=192.168.1.1:80 proto=UDP payload="test" +var canonicalUDPTest = []byte{ + 0x45, 0x00, 0x00, 0x20, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x5c, 0xef, 0x0a, 0x00, 0x00, 0x01, + 0xc0, 0xa8, 0x01, 0x01, 0x30, 0x39, 0x00, 0x50, 0x00, 0x0c, 0x1b, 0xc9, 0x74, 0x65, 0x73, 0x74, +} + +// canonicalUDPHijack: src=10.0.0.1:12345 dst=192.168.1.1:80 proto=UDP payload="hijack" +var canonicalUDPHijack = []byte{ + 0x45, 0x00, 0x00, 0x22, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x5c, 0xed, 0x0a, 0x00, 0x00, 0x01, + 0xc0, 0xa8, 0x01, 0x01, 0x30, 0x39, 0x00, 0x50, 0x00, 0x0e, 0xcd, 0x68, 0x68, 0x69, 0x6a, 0x61, + 0x63, 0x6b, +} + +// canonicalUDPBlocked: src=10.0.0.1:12345 dst=192.168.1.1:443 proto=UDP payload="blocked" +var canonicalUDPBlocked = []byte{ + 0x45, 0x00, 0x00, 0x23, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x5c, 0xec, 0x0a, 0x00, 0x00, 0x01, + 0xc0, 0xa8, 0x01, 0x01, 0x30, 0x39, 0x01, 0xbb, 0x00, 0x0f, 0x60, 0xfc, 0x62, 0x6c, 0x6f, 0x63, + 0x6b, 0x65, 0x64, +} + +// canonicalUDPWrongDest: src=10.0.0.1:12345 dst=172.16.0.1:80 proto=UDP payload="wrong dest" +var canonicalUDPWrongDest = []byte{ + 0x45, 0x00, 0x00, 0x26, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x72, 0x81, 0x0a, 0x00, 0x00, 0x01, + 0xac, 0x10, 0x00, 0x01, 0x30, 0x39, 0x00, 0x50, 0x00, 0x12, 0xf3, 0x53, 0x77, 0x72, 0x6f, 0x6e, + 0x67, 0x20, 0x64, 0x65, 0x73, 0x74, +} + +// canonicalUDPNoSnat: src=10.0.0.1:12345 dst=192.168.1.1:80 proto=UDP payload="no snat" +var canonicalUDPNoSnat = []byte{ + 0x45, 0x00, 0x00, 0x23, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x5c, 0xec, 0x0a, 0x00, 0x00, 0x01, + 0xc0, 0xa8, 0x01, 0x01, 0x30, 0x39, 0x00, 0x50, 0x00, 0x0f, 0x92, 0x58, 0x6e, 0x6f, 0x20, 0x73, + 0x6e, 0x61, 0x74, +} + +// canonicalUDPV4Traffic: src=10.128.0.2:12345 dst=192.168.1.1:80 proto=UDP payload="v4 traffic" +var canonicalUDPV4Traffic = []byte{ + 0x45, 0x00, 0x00, 0x26, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x5c, 0x68, 0x0a, 0x80, 0x00, 0x02, + 0xc0, 0xa8, 0x01, 0x01, 0x30, 0x39, 0x00, 0x50, 0x00, 0x12, 0x2a, 0x42, 0x76, 0x34, 0x20, 0x74, + 0x72, 0x61, 0x66, 0x66, 0x69, 0x63, +} + +// canonicalUDPRoundtrip: src=10.0.0.1:12345 dst=192.168.1.1:80 proto=UDP payload="roundtrip" +var canonicalUDPRoundtrip = []byte{ + 0x45, 0x00, 0x00, 0x25, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x5c, 0xea, 0x0a, 0x00, 0x00, 0x01, + 0xc0, 0xa8, 0x01, 0x01, 0x30, 0x39, 0x00, 0x50, 0x00, 0x11, 0xd4, 0xdc, 0x72, 0x6f, 0x75, 0x6e, + 0x64, 0x74, 0x72, 0x69, 0x70, +} + +// canonicalUDPSnatMe: src=10.0.0.1:12345 dst=192.168.1.1:80 proto=UDP payload="snat me" +var canonicalUDPSnatMe = []byte{ + 0x45, 0x00, 0x00, 0x23, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x5c, 0xec, 0x0a, 0x00, 0x00, 0x01, + 0xc0, 0xa8, 0x01, 0x01, 0x30, 0x39, 0x00, 0x50, 0x00, 0x0f, 0xa9, 0x4c, 0x73, 0x6e, 0x61, 0x74, + 0x20, 0x6d, 0x65, +} + +// Expected outputs after rewriting — built from scratch with the post-rewrite +// addresses, so all checksums are independently correct. + +// canonicalUDPSnatted: canonicalUDP with src rewritten to 169.254.55.96:55555 +var canonicalUDPSnatted = []byte{ + 0x45, 0x00, 0x00, 0x27, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x85, 0x8a, 0xa9, 0xfe, 0x37, 0x60, + 0xc0, 0xa8, 0x01, 0x01, 0xd9, 0x03, 0x00, 0x50, 0x00, 0x13, 0xf1, 0x9d, 0x68, 0x65, 0x6c, 0x6c, + 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64, +} + +// canonicalUDPReplyUnSnatted: canonicalUDPReply with dst rewritten from 169.254.55.96:55555 to 10.0.0.1:12345 +var canonicalUDPReplyUnSnatted = []byte{ + 0x45, 0x00, 0x00, 0x21, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x5c, 0xee, 0xc0, 0xa8, 0x01, 0x01, + 0x0a, 0x00, 0x00, 0x01, 0x00, 0x50, 0x30, 0x39, 0x00, 0x0d, 0xa7, 0xce, 0x72, 0x65, 0x70, 0x6c, + 0x79, +} + +// canonicalTCPSnatted: canonicalTCP with src rewritten to 169.254.55.96:55555 +var canonicalTCPSnatted = []byte{ + 0x45, 0x00, 0x00, 0x36, 0x12, 0x34, 0x40, 0x00, 0x40, 0x06, 0x85, 0x86, 0xa9, 0xfe, 0x37, 0x60, + 0xc0, 0xa8, 0x01, 0x01, 0xd9, 0x03, 0x00, 0x50, 0x12, 0x34, 0x56, 0x78, 0x00, 0x00, 0x00, 0x00, + 0x50, 0x02, 0xff, 0xff, 0x06, 0x40, 0x00, 0x00, 0x47, 0x45, 0x54, 0x20, 0x2f, 0x20, 0x48, 0x54, + 0x54, 0x50, 0x2f, 0x31, 0x2e, 0x31, +} + +// canonicalICMPSnatted: canonicalICMP with src rewritten to 169.254.55.96, id changed from 0x1234 to 0x5678 +var canonicalICMPSnatted = []byte{ + 0x45, 0x00, 0x00, 0x1c, 0x12, 0x34, 0x40, 0x00, 0x40, 0x01, 0x85, 0xa5, 0xa9, 0xfe, 0x37, 0x60, + 0xc0, 0xa8, 0x01, 0x01, 0x08, 0x00, 0xa1, 0x86, 0x56, 0x78, 0x00, 0x01, +} + +func TestCalcNewTransportChecksum_Identity(t *testing.T) { + // Rewriting to the same IP/port should return the same checksum + ip := netip.MustParseAddr("10.0.0.1") + result := calcNewTransportChecksum(0x1234, ip, 80, ip, 80) + assert.Equal(t, uint16(0x1234), result) +} + +func TestCalcNewTransportChecksum_VsCanonical(t *testing.T) { + srcIP := netip.MustParseAddr("10.0.0.1") + snatIP := netip.MustParseAddr("169.254.55.96") + + // Extract the original UDP checksum from canonicalUDP (bytes 26-27) + origChecksum := binary.BigEndian.Uint16(canonicalUDP[26:28]) + + // Compute incrementally + incremental := calcNewTransportChecksum(origChecksum, srcIP, 12345, snatIP, 55555) + + // Verify it matches the checksum in the independently-computed canonicalUDPSnatted + expectedChecksum := binary.BigEndian.Uint16(canonicalUDPSnatted[26:28]) + assert.Equal(t, expectedChecksum, incremental, "incremental checksum should match canonical expected output") +} + +func TestCalcNewICMPChecksum_Identity(t *testing.T) { + // Same values in and out should be identity + result := calcNewICMPChecksum(0xABCD, 0, 0, 1234, 1234) + assert.Equal(t, uint16(0xABCD), result) +} + +func TestRewritePacket_UDP(t *testing.T) { + srcIP := netip.MustParseAddr("10.0.0.1") + dstIP := netip.MustParseAddr("192.168.1.1") + snatIP := netip.MustParseAddr("169.254.55.96") + + pkt := slices.Clone(canonicalUDP) + + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: srcIP, + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + + // SNAT rewrites source: IP at offset 12, port at offset 0 inside transport + oldIP := netip.AddrPortFrom(srcIP, 12345) + newIP := netip.AddrPortFrom(snatIP, 55555) + rewritePacket(pkt, &fp, oldIP, newIP, 12, 0) + + assert.Equal(t, canonicalUDPSnatted, pkt, "rewritten packet should match canonical expected output") +} + +func TestRewritePacket_UDP_UnSNAT(t *testing.T) { + snatIP := netip.MustParseAddr("169.254.55.96") + dstIP := netip.MustParseAddr("192.168.1.1") + origSrcIP := netip.MustParseAddr("10.0.0.1") + + pkt := slices.Clone(canonicalUDPReply) + + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: snatIP, + LocalPort: 80, + RemotePort: 55555, + Protocol: firewall.ProtoUDP, + } + + // UnSNAT rewrites destination: IP at offset 16, port at offset 2 inside transport + oldIP := netip.AddrPortFrom(snatIP, 55555) + newIP := netip.AddrPortFrom(origSrcIP, 12345) + rewritePacket(pkt, &fp, oldIP, newIP, 16, 2) + + assert.Equal(t, canonicalUDPReplyUnSnatted, pkt, "un-SNATted packet should match canonical expected output") +} + +func TestRewritePacket_TCP(t *testing.T) { + srcIP := netip.MustParseAddr("10.0.0.1") + dstIP := netip.MustParseAddr("192.168.1.1") + snatIP := netip.MustParseAddr("169.254.55.96") + + pkt := slices.Clone(canonicalTCP) + + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: srcIP, + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoTCP, + } + + oldIP := netip.AddrPortFrom(srcIP, 12345) + newIP := netip.AddrPortFrom(snatIP, 55555) + rewritePacket(pkt, &fp, oldIP, newIP, 12, 0) + + assert.Equal(t, canonicalTCPSnatted, pkt, "rewritten TCP packet should match canonical expected output") +} + +func TestRewritePacket_ICMP(t *testing.T) { + srcIP := netip.MustParseAddr("10.0.0.1") + dstIP := netip.MustParseAddr("192.168.1.1") + snatIP := netip.MustParseAddr("169.254.55.96") + + pkt := slices.Clone(canonicalICMP) + + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: srcIP, + LocalPort: 0, + RemotePort: 0x1234, // ICMP ID used as port + Protocol: firewall.ProtoICMP, + } + + oldIP := netip.AddrPortFrom(srcIP, 0x1234) + newIP := netip.AddrPortFrom(snatIP, 0x5678) + rewritePacket(pkt, &fp, oldIP, newIP, 12, 0) + + assert.Equal(t, canonicalICMPSnatted, pkt, "rewritten ICMP packet should match canonical expected output") +} + +func TestRewritePacket_Roundtrip(t *testing.T) { + // Test that SNAT followed by unSNAT produces the original packet + srcIP := netip.MustParseAddr("10.0.0.1") + dstIP := netip.MustParseAddr("192.168.1.1") + snatIP := netip.MustParseAddr("169.254.55.96") + + pkt := slices.Clone(canonicalUDPRoundtrip) + + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: srcIP, + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + + // SNAT: rewrite source + oldSrc := netip.AddrPortFrom(srcIP, 12345) + newSrc := netip.AddrPortFrom(snatIP, 55555) + rewritePacket(pkt, &fp, oldSrc, newSrc, 12, 0) + + // Verify intermediate state is not the original + require.NotEqual(t, canonicalUDPRoundtrip, pkt) + + // UnSNAT: rewrite source back + rewritePacket(pkt, &fp, newSrc, oldSrc, 12, 0) + + // Packet should be byte-for-byte identical to original + assert.Equal(t, canonicalUDPRoundtrip, pkt, "packet should be identical after roundtrip SNAT/unSNAT") +} + +func TestSnatInfo_Valid(t *testing.T) { + t.Run("nil is invalid", func(t *testing.T) { + var s *snatInfo + assert.False(t, s.Valid()) + }) + + t.Run("zero value is invalid", func(t *testing.T) { + s := &snatInfo{} + assert.False(t, s.Valid()) + }) + + t.Run("with valid src is valid", func(t *testing.T) { + s := &snatInfo{ + Src: netip.AddrPortFrom(netip.MustParseAddr("10.0.0.1"), 1234), + SrcVpnIp: netip.MustParseAddr("fd00::1"), + SnatPort: 55555, + } + assert.True(t, s.Valid()) + }) +} + +func TestFirewall_ShouldUnSNAT(t *testing.T) { + snatAddr := netip.MustParseAddr("169.254.55.96") + + t.Run("no snat addr configured", func(t *testing.T) { + fw := &Firewall{} + fp := &firewall.Packet{RemoteAddr: snatAddr} + assert.False(t, fw.ShouldUnSNAT(fp)) + }) + + t.Run("packet to snat addr", func(t *testing.T) { + fw := &Firewall{snatAddr: snatAddr} + fp := &firewall.Packet{RemoteAddr: snatAddr} + assert.True(t, fw.ShouldUnSNAT(fp)) + }) + + t.Run("packet to different addr", func(t *testing.T) { + fw := &Firewall{snatAddr: snatAddr} + fp := &firewall.Packet{RemoteAddr: netip.MustParseAddr("10.0.0.1")} + assert.False(t, fw.ShouldUnSNAT(fp)) + }) +} + +func TestFirewall_IdentifyNetworkType_SNATPeer(t *testing.T) { + snatAddr := netip.MustParseAddr("169.254.55.96") + + t.Run("v4 packet from v6-only host without networks table", func(t *testing.T) { + fw := &Firewall{snatAddr: snatAddr} + h := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("fd00::1")}} + fp := firewall.Packet{ + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalAddr: netip.MustParseAddr("192.168.1.1"), + } + assert.Equal(t, NetworkTypeUncheckedSNATPeer, fw.identifyNetworkType(h, fp)) + }) + + t.Run("v4 packet from v4 host is not snat peer", func(t *testing.T) { + fw := &Firewall{snatAddr: snatAddr} + h := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("10.0.0.1")}} + fp := firewall.Packet{ + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalAddr: netip.MustParseAddr("192.168.1.1"), + } + assert.Equal(t, NetworkTypeVPN, fw.identifyNetworkType(h, fp)) + }) + + t.Run("v6 packet from v6 host is VPN", func(t *testing.T) { + fw := &Firewall{snatAddr: snatAddr} + h := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("fd00::1")}} + fp := firewall.Packet{ + RemoteAddr: netip.MustParseAddr("fd00::1"), + LocalAddr: netip.MustParseAddr("fd00::2"), + } + assert.Equal(t, NetworkTypeVPN, fw.identifyNetworkType(h, fp)) + }) + + t.Run("mismatched v4 from v4 host is invalid", func(t *testing.T) { + fw := &Firewall{snatAddr: snatAddr} + h := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("10.0.0.1")}} + fp := firewall.Packet{ + RemoteAddr: netip.MustParseAddr("10.0.0.99"), + LocalAddr: netip.MustParseAddr("192.168.1.1"), + } + assert.Equal(t, NetworkTypeInvalidPeer, fw.identifyNetworkType(h, fp)) + }) +} + +func TestFirewall_AllowNetworkType_SNAT(t *testing.T) { + t.Run("snat peer allowed with snat addr", func(t *testing.T) { + fw := &Firewall{snatAddr: netip.MustParseAddr("169.254.55.96")} + assert.NoError(t, fw.allowNetworkType(NetworkTypeUncheckedSNATPeer)) + }) + + t.Run("snat peer rejected without snat addr", func(t *testing.T) { + fw := &Firewall{} + assert.ErrorIs(t, fw.allowNetworkType(NetworkTypeUncheckedSNATPeer), ErrInvalidRemoteIP) + }) + + t.Run("vpn always allowed", func(t *testing.T) { + fw := &Firewall{} + assert.NoError(t, fw.allowNetworkType(NetworkTypeVPN)) + }) + + t.Run("unsafe always allowed", func(t *testing.T) { + fw := &Firewall{} + assert.NoError(t, fw.allowNetworkType(NetworkTypeUnsafe)) + }) + + t.Run("invalid peer rejected", func(t *testing.T) { + fw := &Firewall{} + assert.ErrorIs(t, fw.allowNetworkType(NetworkTypeInvalidPeer), ErrInvalidRemoteIP) + }) + + t.Run("vpn peer rejected", func(t *testing.T) { + fw := &Firewall{} + assert.ErrorIs(t, fw.allowNetworkType(NetworkTypeVPNPeer), ErrPeerRejected) + }) +} + +func TestFirewall_FindUsableSNATPort(t *testing.T) { + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + snatAddr := netip.MustParseAddr("169.254.55.96") + + t.Run("finds first available port", func(t *testing.T) { + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + + fp := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: snatAddr, + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cn := &conn{} + err := fw.findUsableSNATPort(&fp, cn) + require.NoError(t, err) + // Port should have been assigned + assert.Equal(t, uint16(12345), fp.RemotePort, "should use original port if available") + }) + + t.Run("skips occupied port", func(t *testing.T) { + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + + fp := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: snatAddr, + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + // Occupy the port + fw.Conntrack.Lock() + fw.Conntrack.Conns[fp] = &conn{} + fw.Conntrack.Unlock() + + cn := &conn{} + err := fw.findUsableSNATPort(&fp, cn) + require.NoError(t, err) + assert.NotEqual(t, uint16(12345), fp.RemotePort, "should pick a different port") + }) + + t.Run("returns error on exhaustion", func(t *testing.T) { + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + + // Fill all 0x7ff ports + baseFP := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: snatAddr, + LocalPort: 80, + Protocol: firewall.ProtoUDP, + } + fw.Conntrack.Lock() + for i := 0; i < 0x7ff; i++ { + fp := baseFP + fp.RemotePort = uint16(0x7ff + i) + fw.Conntrack.Conns[fp] = &conn{} + } + fw.Conntrack.Unlock() + + // Try to find a port starting from 0x7ff + fp := baseFP + fp.RemotePort = 0x7ff + cn := &conn{} + err := fw.findUsableSNATPort(&fp, cn) + assert.ErrorIs(t, err, ErrCannotSNAT) + }) +} + +func TestFirewall_ApplySnat(t *testing.T) { + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + snatAddr := netip.MustParseAddr("169.254.55.96") + peerV6Addr := netip.MustParseAddr("fd00::1") + dstIP := netip.MustParseAddr("192.168.1.1") + + t.Run("new flow from v6 host", func(t *testing.T) { + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw.snatAddr = snatAddr + + pkt := slices.Clone(canonicalUDPTest) + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cn := &conn{} + h := &HostInfo{vpnAddrs: []netip.Addr{peerV6Addr}} + + err := fw.applySnat(pkt, &fp, cn, h) + require.NoError(t, err) + + // Should have created snat info + require.True(t, cn.snat.Valid()) + assert.Equal(t, peerV6Addr, cn.snat.SrcVpnIp) + assert.Equal(t, netip.MustParseAddr("10.0.0.1"), cn.snat.Src.Addr()) + assert.Equal(t, uint16(12345), cn.snat.Src.Port()) + + // Packet source should be rewritten to snatAddr + gotSrcIP, _ := netip.AddrFromSlice(pkt[12:16]) + assert.Equal(t, snatAddr, gotSrcIP) + }) + + t.Run("existing flow with matching identity", func(t *testing.T) { + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw.snatAddr = snatAddr + + pkt := slices.Clone(canonicalUDPTest) + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cn := &conn{ + snat: &snatInfo{ + Src: netip.AddrPortFrom(netip.MustParseAddr("10.0.0.1"), 12345), + SrcVpnIp: peerV6Addr, + SnatPort: 55555, + }, + } + h := &HostInfo{vpnAddrs: []netip.Addr{peerV6Addr}} + + err := fw.applySnat(pkt, &fp, cn, h) + require.NoError(t, err) + + // Source should be rewritten + gotSrcIP, _ := netip.AddrFromSlice(pkt[12:16]) + assert.Equal(t, snatAddr, gotSrcIP) + }) + + t.Run("identity mismatch rejected", func(t *testing.T) { + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw.snatAddr = snatAddr + + pkt := slices.Clone(canonicalUDPTest) + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cn := &conn{ + snat: &snatInfo{ + Src: netip.AddrPortFrom(netip.MustParseAddr("10.0.0.1"), 12345), + SrcVpnIp: netip.MustParseAddr("fd00::99"), // Different VPN IP + SnatPort: 55555, + }, + } + // Attacker has a different VPN address + h := &HostInfo{vpnAddrs: []netip.Addr{peerV6Addr}} + + err := fw.applySnat(pkt, &fp, cn, h) + assert.ErrorIs(t, err, ErrSNATIdentityMismatch) + }) + + t.Run("no snat addr configured", func(t *testing.T) { + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + + pkt := slices.Clone(canonicalUDPTest) + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cn := &conn{} + h := &HostInfo{vpnAddrs: []netip.Addr{peerV6Addr}} + + err := fw.applySnat(pkt, &fp, cn, h) + assert.ErrorIs(t, err, ErrCannotSNAT) + }) + + t.Run("v4 host rejected for new flow", func(t *testing.T) { + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw.snatAddr = snatAddr + + pkt := slices.Clone(canonicalUDPTest) + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cn := &conn{} + // This host has a v4 address - can't SNAT for it + h := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("10.0.0.50")}} + + err := fw.applySnat(pkt, &fp, cn, h) + assert.ErrorIs(t, err, ErrCannotSNAT) + }) +} + +func TestFirewall_UnSnat(t *testing.T) { + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + snatAddr := netip.MustParseAddr("169.254.55.96") + peerV6Addr := netip.MustParseAddr("fd00::1") + origSrcIP := netip.MustParseAddr("10.0.0.1") + + t.Run("successful unsnat", func(t *testing.T) { + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw.snatAddr = snatAddr + + // Create a conntrack entry for the snatted flow + snatFP := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: snatAddr, + LocalPort: 80, + RemotePort: 55555, + Protocol: firewall.ProtoUDP, + } + fw.Conntrack.Lock() + fw.Conntrack.Conns[snatFP] = &conn{ + snat: &snatInfo{ + Src: netip.AddrPortFrom(origSrcIP, 12345), + SrcVpnIp: peerV6Addr, + SnatPort: 55555, + }, + } + fw.Conntrack.Unlock() + + pkt := slices.Clone(canonicalUDPReply) + + fp := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: snatAddr, + LocalPort: 80, + RemotePort: 55555, + Protocol: firewall.ProtoUDP, + } + + result := fw.unSnat(pkt, &fp) + assert.True(t, result.IsValid()) + assert.Equal(t, peerV6Addr, result) + + // Destination should be rewritten to the original source + gotDstIP, _ := netip.AddrFromSlice(pkt[16:20]) + assert.Equal(t, origSrcIP, gotDstIP) + }) + + t.Run("no conntrack entry", func(t *testing.T) { + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw.snatAddr = snatAddr + + pkt := slices.Clone(canonicalUDPReply) + fp := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: snatAddr, + LocalPort: 80, + RemotePort: 55555, + Protocol: firewall.ProtoUDP, + } + + result := fw.unSnat(pkt, &fp) + assert.False(t, result.IsValid()) + }) +} + +func TestFirewall_Drop_SNATFullFlow(t *testing.T) { + // Integration test: a complete SNAT flow through Drop + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + snatAddr := netip.MustParseAddr("169.254.55.96") + myV6Prefix := netip.MustParsePrefix("fd00::1/128") + unsafeNet := netip.MustParsePrefix("192.168.0.0/16") + + myCert := &dummyCert{ + name: "me", + networks: []netip.Prefix{myV6Prefix}, + unsafeNetworks: []netip.Prefix{unsafeNet}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert, snatAddr) + fw.snatAddr = snatAddr + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "any", "", "")) + + // Set up the peer: an IPv6-only host sending IPv4 traffic + peerV6Addr := netip.MustParseAddr("fd00::2") + peerCert := &dummyCert{ + name: "peer", + networks: []netip.Prefix{netip.MustParsePrefix("fd00::2/128")}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(myV6Prefix) + + h := &HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &cert.CachedCertificate{ + Certificate: peerCert, + InvertedGroups: map[string]struct{}{"default-group": {}}, + }, + }, + vpnAddrs: []netip.Addr{peerV6Addr}, + } + h.buildNetworks(myVpnNetworksTable, peerCert) + + pkt := slices.Clone(canonicalUDPSnatMe) + + fp := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cp := cert.NewCAPool() + + // Drop should succeed and SNAT the packet + err := fw.Drop(fp, pkt, true, h, cp, nil) + require.NoError(t, err) + + // After Drop, the source should be rewritten to the snat addr + gotSrcIP, _ := netip.AddrFromSlice(pkt[12:16]) + assert.Equal(t, snatAddr, gotSrcIP) +} + +func TestHasOnlyV6Addresses(t *testing.T) { + t.Run("v6 only", func(t *testing.T) { + h := &HostInfo{vpnAddrs: []netip.Addr{ + netip.MustParseAddr("fd00::1"), + netip.MustParseAddr("fd00::2"), + }} + assert.True(t, h.HasOnlyV6Addresses()) + }) + + t.Run("v4 only", func(t *testing.T) { + h := &HostInfo{vpnAddrs: []netip.Addr{ + netip.MustParseAddr("10.0.0.1"), + }} + assert.False(t, h.HasOnlyV6Addresses()) + }) + + t.Run("mixed v4 and v6", func(t *testing.T) { + h := &HostInfo{vpnAddrs: []netip.Addr{ + netip.MustParseAddr("fd00::1"), + netip.MustParseAddr("10.0.0.1"), + }} + assert.False(t, h.HasOnlyV6Addresses()) + }) +} + +// --- Adversarial SNAT Tests --- + +func TestFirewall_ApplySnat_CrossHostHijack(t *testing.T) { + // Host A (fd00::1) establishes SNAT flow. Host B (fd00::2) sends a packet + // matching the same conntrack key but with a different identity. + // applySnat must reject with ErrSNATIdentityMismatch and leave the packet unmodified. + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + snatAddr := netip.MustParseAddr("169.254.55.96") + hostA := netip.MustParseAddr("fd00::1") + hostB := netip.MustParseAddr("fd00::2") + + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw.snatAddr = snatAddr + + // Simulate Host A having established a flow + cn := &conn{ + snat: &snatInfo{ + Src: netip.AddrPortFrom(netip.MustParseAddr("10.0.0.1"), 12345), + SrcVpnIp: hostA, + SnatPort: 55555, + }, + } + + // Host B tries to reuse the same conntrack entry + pkt := slices.Clone(canonicalUDPHijack) + + fp := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + hB := &HostInfo{vpnAddrs: []netip.Addr{hostB}} + + err := fw.applySnat(pkt, &fp, cn, hB) + assert.ErrorIs(t, err, ErrSNATIdentityMismatch) + assert.Equal(t, canonicalUDPHijack, pkt, "packet bytes must be unmodified after identity mismatch") +} + +func TestFirewall_ApplySnat_MixedStackRejected(t *testing.T) { + // A host with both v4 and v6 VPN addresses should never get SNAT treatment. + // Test both orderings of vpnAddrs to verify behavior. + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + snatAddr := netip.MustParseAddr("169.254.55.96") + dstIP := netip.MustParseAddr("192.168.1.1") + + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + + t.Run("v6 first then v4", func(t *testing.T) { + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw.snatAddr = snatAddr + + pkt := slices.Clone(canonicalUDPTest) + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cn := &conn{} + // Mixed-stack: v6 first. applySnat checks vpnAddrs[0].Is6() which is true, + // so it would create a flow. But the caller (Drop) guards with HasOnlyV6Addresses(). + // This test documents that applySnat alone doesn't prevent mixed-stack SNAT. + h := &HostInfo{vpnAddrs: []netip.Addr{ + netip.MustParseAddr("fd00::1"), + netip.MustParseAddr("10.0.0.50"), + }} + + err := fw.applySnat(pkt, &fp, cn, h) + // applySnat only checks vpnAddrs[0].Is6(), so this succeeds. + // The real guard is in Drop() via HasOnlyV6Addresses(). + assert.NoError(t, err, "applySnat alone allows v6-first mixed-stack (guarded by Drop)") + }) + + t.Run("v4 first then v6", func(t *testing.T) { + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw.snatAddr = snatAddr + + pkt := slices.Clone(canonicalUDPTest) + + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cn := &conn{} + // Mixed-stack: v4 first. vpnAddrs[0].Is6() is false -> ErrCannotSNAT. + h := &HostInfo{vpnAddrs: []netip.Addr{ + netip.MustParseAddr("10.0.0.50"), + netip.MustParseAddr("fd00::1"), + }} + + err := fw.applySnat(pkt, &fp, cn, h) + assert.ErrorIs(t, err, ErrCannotSNAT) + assert.Equal(t, canonicalUDPTest, pkt, "packet bytes must be unmodified on error") + }) +} + +func TestFirewall_ApplySnat_PacketUnmodifiedOnError(t *testing.T) { + // When applySnat returns an error, the packet must not be partially rewritten. + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + dstIP := netip.MustParseAddr("192.168.1.1") + + t.Run("no snatAddr configured", func(t *testing.T) { + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + + pkt := slices.Clone(canonicalUDPTest) + + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cn := &conn{} + h := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("fd00::1")}} + + err := fw.applySnat(pkt, &fp, cn, h) + require.Error(t, err) + assert.Equal(t, canonicalUDPTest, pkt, "packet must be byte-for-byte identical after error") + }) + + t.Run("identity mismatch", func(t *testing.T) { + snatAddr := netip.MustParseAddr("169.254.55.96") + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw.snatAddr = snatAddr + + pkt := slices.Clone(canonicalUDPTest) + + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cn := &conn{ + snat: &snatInfo{ + Src: netip.AddrPortFrom(netip.MustParseAddr("10.0.0.1"), 12345), + SrcVpnIp: netip.MustParseAddr("fd00::99"), + SnatPort: 55555, + }, + } + h := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("fd00::1")}} + + err := fw.applySnat(pkt, &fp, cn, h) + require.ErrorIs(t, err, ErrSNATIdentityMismatch) + assert.Equal(t, canonicalUDPTest, pkt, "packet must be byte-for-byte identical after identity mismatch") + }) + + t.Run("v4 host rejected", func(t *testing.T) { + snatAddr := netip.MustParseAddr("169.254.55.96") + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw.snatAddr = snatAddr + + pkt := slices.Clone(canonicalUDPTest) + + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cn := &conn{} + h := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("10.0.0.50")}} + + err := fw.applySnat(pkt, &fp, cn, h) + require.ErrorIs(t, err, ErrCannotSNAT) + assert.Equal(t, canonicalUDPTest, pkt, "packet must be byte-for-byte identical after v4 host rejection") + }) +} + +func TestFirewall_UnSnat_NonSNATConntrack(t *testing.T) { + // A conntrack entry exists but has snat=nil. unSnat should return an invalid addr + // and not rewrite the packet. + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + snatAddr := netip.MustParseAddr("169.254.55.96") + + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw.snatAddr = snatAddr + + // Create a conntrack entry with snat=nil (a normal non-SNAT connection) + snatFP := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: snatAddr, + LocalPort: 80, + RemotePort: 55555, + Protocol: firewall.ProtoUDP, + } + fw.Conntrack.Lock() + fw.Conntrack.Conns[snatFP] = &conn{ + snat: nil, // deliberately nil + } + fw.Conntrack.Unlock() + + pkt := slices.Clone(canonicalUDPReply) + + fp := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: snatAddr, + LocalPort: 80, + RemotePort: 55555, + Protocol: firewall.ProtoUDP, + } + + result := fw.unSnat(pkt, &fp) + assert.False(t, result.IsValid(), "unSnat should return invalid addr for non-SNAT conntrack entry") + assert.Equal(t, canonicalUDPReply, pkt, "packet must not be rewritten when conntrack has no snat info") +} + +func TestFirewall_Drop_FirewallBlocksSNAT(t *testing.T) { + // Firewall rules only allow port 80. An SNAT-eligible packet to port 443 + // must be rejected with ErrNoMatchingRule BEFORE any SNAT rewriting occurs. + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + snatAddr := netip.MustParseAddr("169.254.55.96") + + myCert := &dummyCert{ + name: "me", + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert, snatAddr) + fw.snatAddr = snatAddr + // Only allow port 80 inbound + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 80, 80, []string{"any"}, "", "", "any", "", "")) + + peerV6Addr := netip.MustParseAddr("fd00::2") + peerCert := &dummyCert{ + name: "peer", + networks: []netip.Prefix{netip.MustParsePrefix("fd00::2/128")}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::1/128")) + + h := &HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &cert.CachedCertificate{ + Certificate: peerCert, + InvertedGroups: map[string]struct{}{"default-group": {}}, + }, + }, + vpnAddrs: []netip.Addr{peerV6Addr}, + } + h.buildNetworks(myVpnNetworksTable, peerCert) + + // Send to port 443 (not allowed) + pkt := slices.Clone(canonicalUDPBlocked) + + fp := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 443, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cp := cert.NewCAPool() + + err := fw.Drop(fp, pkt, true, h, cp, nil) + assert.ErrorIs(t, err, ErrNoMatchingRule, "firewall should block SNAT-eligible traffic that doesn't match rules") + assert.Equal(t, canonicalUDPBlocked, pkt, "packet must not be rewritten when firewall blocks it") +} + +func TestFirewall_Drop_SNATLocalAddrNotRoutable(t *testing.T) { + // An SNAT peer sends IPv4 traffic to an address NOT in routableNetworks. + // willingToHandleLocalAddr should reject with ErrInvalidLocalIP. + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + snatAddr := netip.MustParseAddr("169.254.55.96") + + myCert := &dummyCert{ + name: "me", + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert, snatAddr) + fw.snatAddr = snatAddr + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "any", "", "")) + + peerV6Addr := netip.MustParseAddr("fd00::2") + peerCert := &dummyCert{ + name: "peer", + networks: []netip.Prefix{netip.MustParsePrefix("fd00::2/128")}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::1/128")) + + h := &HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &cert.CachedCertificate{ + Certificate: peerCert, + InvertedGroups: map[string]struct{}{"default-group": {}}, + }, + }, + vpnAddrs: []netip.Addr{peerV6Addr}, + } + h.buildNetworks(myVpnNetworksTable, peerCert) + + // Dest 172.16.0.1 is NOT in our routableNetworks (which only has fd00::1/128 and 192.168.0.0/16) + pkt := slices.Clone(canonicalUDPWrongDest) + + fp := firewall.Packet{ + LocalAddr: netip.MustParseAddr("172.16.0.1"), + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cp := cert.NewCAPool() + + err := fw.Drop(fp, pkt, true, h, cp, nil) + assert.ErrorIs(t, err, ErrInvalidLocalIP, "traffic to non-routable local address should be rejected") +} + +func TestFirewall_Drop_NoSnatAddrRejectsV6Peer(t *testing.T) { + // Firewall has no snatAddr configured. An IPv6-only peer sends IPv4 traffic. + // allowNetworkType(UncheckedSNATPeer) should reject with ErrInvalidRemoteIP. + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + myCert := &dummyCert{ + name: "me", + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert, netip.Addr{}) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "any", "", "")) + + peerV6Addr := netip.MustParseAddr("fd00::2") + peerCert := &dummyCert{ + name: "peer", + networks: []netip.Prefix{netip.MustParsePrefix("fd00::2/128")}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::1/128")) + + h := &HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &cert.CachedCertificate{ + Certificate: peerCert, + InvertedGroups: map[string]struct{}{"default-group": {}}, + }, + }, + vpnAddrs: []netip.Addr{peerV6Addr}, + } + h.buildNetworks(myVpnNetworksTable, peerCert) + + pkt := slices.Clone(canonicalUDPNoSnat) + + fp := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cp := cert.NewCAPool() + + err := fw.Drop(fp, pkt, true, h, cp, nil) + assert.ErrorIs(t, err, ErrInvalidRemoteIP, "v6 peer with no snatAddr should be rejected") +} + +func TestFirewall_Drop_IPv4HostNotSNATted(t *testing.T) { + // An IPv4 VPN host sends IPv4 traffic. Even though the router has snatAddr + // configured and the traffic is IPv4, the firewall must NOT treat this as + // UncheckedSNATPeer. The packet must not be SNAT-rewritten. + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + snatAddr := netip.MustParseAddr("169.254.55.96") + + t.Run("v6-only router rejects v4 peer as VPNPeer", func(t *testing.T) { + // When the router is v6-only, the v4 peer's address is outside our VPN + // networks -> classified as NetworkTypeVPNPeer -> rejected (not SNATted). + myCert := &dummyCert{ + name: "me", + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert, snatAddr) + fw.snatAddr = snatAddr + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "any", "", "")) + + peerV4Addr := netip.MustParseAddr("10.128.0.2") + peerCert := &dummyCert{ + name: "v4peer", + networks: []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::1/128")) + + h := &HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &cert.CachedCertificate{ + Certificate: peerCert, + InvertedGroups: map[string]struct{}{"default-group": {}}, + }, + }, + vpnAddrs: []netip.Addr{peerV4Addr}, + } + h.buildNetworks(myVpnNetworksTable, peerCert) + + pkt := slices.Clone(canonicalUDPV4Traffic) + + fp := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: netip.MustParseAddr("10.128.0.2"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cp := cert.NewCAPool() + + err := fw.Drop(fp, pkt, true, h, cp, nil) + assert.ErrorIs(t, err, ErrPeerRejected, "IPv4 peer should be rejected as VPNPeer, not treated as SNAT") + assert.Equal(t, canonicalUDPV4Traffic, pkt, "packet must not be rewritten when peer is rejected") + }) + + t.Run("identifyNetworkType classifies v4 peer correctly", func(t *testing.T) { + // Directly verify that identifyNetworkType returns the right type for + // an IPv4 peer (not UncheckedSNATPeer). + fw := &Firewall{snatAddr: snatAddr} + + // Simple case: v4 host, no networks table + h := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("10.128.0.2")}} + fp := firewall.Packet{ + RemoteAddr: netip.MustParseAddr("10.128.0.2"), + LocalAddr: netip.MustParseAddr("192.168.1.1"), + } + nwType := fw.identifyNetworkType(h, fp) + assert.Equal(t, NetworkTypeVPN, nwType, "v4 peer using its own VPN addr should be NetworkTypeVPN") + assert.NotEqual(t, NetworkTypeUncheckedSNATPeer, nwType, "must NOT be classified as SNAT peer") + }) + + t.Run("identifyNetworkType v4 peer with mismatched source", func(t *testing.T) { + // v4 host sends with a source IP that doesn't match its VPN addr + fw := &Firewall{snatAddr: snatAddr} + + h := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("10.128.0.2")}} + fp := firewall.Packet{ + RemoteAddr: netip.MustParseAddr("10.0.0.99"), // Not the peer's VPN addr + LocalAddr: netip.MustParseAddr("192.168.1.1"), + } + nwType := fw.identifyNetworkType(h, fp) + assert.Equal(t, NetworkTypeInvalidPeer, nwType, "v4 peer with mismatched source should be InvalidPeer") + assert.NotEqual(t, NetworkTypeUncheckedSNATPeer, nwType, "must NOT be classified as SNAT peer") + }) +}