From 70399ea533bf1c2960a0be8276b47649b320bd6d Mon Sep 17 00:00:00 2001 From: JackDoan Date: Wed, 14 Jan 2026 12:36:55 -0600 Subject: [PATCH 01/31] use in-Nebula SNAT to send IPv4 UnsafeNetworks traffic over an IPv6 overlay --- SNAT.md | 102 +++++++++++++ cert/cert_v2.go | 2 +- firewall.go | 332 ++++++++++++++++++++++++++++++++++------- firewall/packet.go | 4 + firewall_test.go | 292 ++++++++++++++++++++++-------------- hostmap.go | 12 ++ inside.go | 29 +++- interface.go | 3 +- main.go | 6 +- outside.go | 2 +- overlay/tun.go | 10 +- overlay/tun_android.go | 4 +- overlay/tun_darwin.go | 4 +- overlay/tun_freebsd.go | 4 +- overlay/tun_ios.go | 4 +- overlay/tun_linux.go | 61 ++++++-- overlay/tun_netbsd.go | 4 +- overlay/tun_openbsd.go | 4 +- overlay/tun_tester.go | 4 +- overlay/tun_windows.go | 4 +- overlay/user.go | 2 +- snat.go | 91 +++++++++++ 22 files changed, 770 insertions(+), 210 deletions(-) create mode 100644 SNAT.md create mode 100644 snat.go diff --git a/SNAT.md b/SNAT.md new file mode 100644 index 00000000..230e44a1 --- /dev/null +++ b/SNAT.md @@ -0,0 +1,102 @@ +# Don't merge me + +# Accessing IPv4-only UnsafeNetworks via an IPv6-only overlay + +## Background +Nebula is an VPN-like connectivity solution. It provides what we call an "overlay network", +an additional IP-addressed network on top of one-or-more traditional "underlay networks". +As long as two devices with Nebula certificates (credentials, essentially signed keypairs with metadata) can find each other +and exchange traffic via a common underlay network (often this is the Internet), +they will also be able to exchange traffic securely via a point-to-point, encrypted, authenticated tunnel. + +Typically, all Nebula traffic is strongly associated with the Nebula certificate of the sender +(that is, the source IP of all packets matches the IP listed in the sender's certificate). +However, it is useful to be able to bend this rule. That is why there is another field in the Nebula certificate, named UnsafeNetworks, +which lists the network prefixes that the host bearing this certificate is allowed to "speak for". + +## Problem Statement +We want IPv6-only overlay networks to be able to carry IPv4 traffic to reach off-overlay hosts via UnsafeNetworks + +### Scenario + +To illustrate this scenario, we will define 3 hosts: +* a Phone, running Nebula, assigned the overlay IP fd01::AA/64. It has an undefined underlay, but we assert that it always has working IPv4 OR IPv6 connectivity to Router. +* a Router, running Nebula, assigned the overlay IP fd01::55/64. It has a stable underlay link that Phone can always reach. +* a Printer, which cannot run Nebula, and is only capable of IPv4 communication. It has a direct link to Router, but the Phone cannot reach it directly. + +You, the User, wish to use your Phone to print out something on the Printer while you're away from home. How can we make this possible with your IPv6-only Nebula overlay? +In particular, your Phone may connect to any cellular or public WiFi network, and we cannot control the IP address it will be assigned. If you MUST print, an IP conflict is not acceptable. +Therefore, we cannot simply dismiss this problem by suggesting that you assign a small IPv4 network within your overlay. Sure, it probably works, and in this toy scenario, the odds of a conflict are pretty small. But it scales very poorly. What if a whole company needs to use this printer (or perhaps a less contrived need?) +We can do better. + +## Solution + +* Even though Phone and Router lack IPv4 assignments, we can still put V4 addresses on their tun devices. +* Each overlay host who wishes to use this feature shall select (or configure?) an assignment within 169.254.0.0/16, the IPv4 link-local range + * this is a pretty small space, but it confines the region of IP conflict to a much smaller domain. And, because overlay hosts will never dial one another with this address, cryptographic verification of it via the certificate is less important. + * On Phone, Nebula will configure an unsafe_route to the Printer using this address. Because it is a device route, we do not need to tell the operating system the address of the next hop (no `via`) + * On Router, Nebula will use this address to masquerade packets from Phone. You'll see! +* Let's walk through setting up a TCP session between Phone and Printer in this scheme: + * Phone sends SYN to the printer's underlay IPv4 address + * This packet lands on Phone's Nebula tun device + * Nebula locates Router as the destination for this packet, as defined in `tun.unsafe_routes` + * Nebula checks the packet against the outbound chain: + * the destination IP of Printer is listed in Router's UnsafeNetworks, so that check will pass + * Phone's source IP is not listed in any certificate, but because the destination address is of `NetworkTypeUnsafe` and this is an outgoing flow, we keep going + * Actual outbound firewall rules get checked, assume they pass + * conntrack entry created to permit replies + * Phone encrypts the packet and sends it to Router + * Router gets the packet from Phone, and decrypts it. It is passed to the Nebula firewall for evaluation: + * `firewall.Drop()` on the Router's Nebula inbound rules: + * Because Router is configured to allow SNAT, and this packet is an IPv4 packet from a IPv6-only host, the firewall module enters "snat mode" (`TODO explain?`) + * This is a new flow, so the conntrack lookup for it fails + * `firewall.identifyNetworkType()` + * identify what "kind" of remote address (this is the inbound firewall, so the remote address is the packet's src IP) we've been given + * `NetworkTypeVPN`, for example is a remote address that matches the peer's certificate + * In this case, because the traffic is IPv4 traffic flowing from an IPv6-only host, and we've opted into supporting SNAT, this traffic is marked as `NetworkTypeUncheckedSNATPeer` + * `firewall.allowNetworkType()` will allow `NetworkTypeUncheckedSNATPeer` traffic to proceed because we have opted into SNAT + * `firewall.willingToHandleLocalAddr()` now needs to check if we're willing to accept the destination address + * Because this traffic is addressed to a destination listed in our UnsafeNetworks, it's considered "routable" and passes this check + * Nebula's firewall rules are evaluated as normal. In particular, the `cidr` parameter will be checked against the IPv4 address, NOT the IPv6 address in the Phone's certificate + * @Nate I think this is "correct", but could be a source of footgun + * Let's assume the Nebula rules accept the traffic + * We create a conntrack entry for this flow + * We do not want to transmit with the IPv4 source IP we got from Phone. We don't want the Phone's IP assignments (in this scheme) to enter the network-space on Router at all. + * To this end, we rewrite the source address (and port, if needed) to our own pre-selected IPv4 link-local address. This address will never actually leave the machine, but we need it so return traffic can be routed back to the nebula tun on Router + * Replace source IP with "Router's SNAT address" + * Look in our conntrack table, and ensure we do not already have a flow that matches this srcip/srcport/dstip/dstport/proto tuple + * if we do, increment srcport until we find an empty slot. Only use ephemeral ports. This gives 0x7ff flows per dstip/dstport/proto tuple, which ought to be plenty. + * Record the original srcip/srcport as part of the conntrack data for later + * Fix checksums + * Nebula writes the rewritten packet to Router's tun + * netfilter picks up the packet. In this example, Router is using `iptables`. A rule in the `nat` table similar to `-A POSTROUTING -d PRINTER_UNDERLAY_IP_HERE/32 -j MASQUERADE` is hit + * This ensures that "Router's SNAT address" never actually leaves Router. + * The packet leaves Router, and hits Printer + * Printer gleefully accepts the SYN from Router, and replies with an ACK + * iptables on Router de-masquerades the packet, and delivers it to the Nebula tun + * Nebula reads the packet off the tun. Because it came from the tun, and not UDP, remember that this is considered "inside" traffic and will be evaluated as "outbound" traffic by Nebula. + * Because this is inside traffic, it needs to be associated with a HostInfo before we can pass it to the firewall. + * Check that the packet is addressed to the "Router's SNAT address". If so, attempt to un-SNAT by "peeking" into conntrack + * If a Router needs to speak to _another_ Router with v4-in-v6 unsafe_routes like this, it _must_ use a distinct address from the "Router's SNAT address" + * the easy way on Linux to assure this is to set a route for the "Router SNAT address" to the Nebula tun, but not actually assign the address + * The "peek" into conntrack succeeds, and we find everything we need to rewrite the packet for transmission to Phone, as well as Phone's overlay IP, which lets us locate Phone's HostInfo + * The packet is rewritten, replacing the destination address/port to match the ones Phone expects + * checksums corrected + * Check the Nebula firewall, and see that we have a valid conntrack entry (wow!) + * we could _technically_ skip this check, but I dislike not passing all traffic we intend to accept into `firewall.Drop()`. The second conntrack locking-lookup does suck. There's room for improvement here. + * The traffic is accepted, encrypted, and sent back to Phone + * Phone gets the packet from Router, decrypts it, checks the firewall + * we have a conntrack entry for this flow, so the firewall accepts it, and delivers it to the tun + * Both sides now have a nice conntrack entry, and traffic should continue to flow uninterrupted until it expires + +This conntrack entry technically creates a risk though. Let's examine that. +The Phone will accept inbound traffic matching the conntrack spec from any Router-like host authorized to speak for that UnsafeRoute, not just Router. In theory, this is desireable, and the risk is mitigated by accepting/trusting Nebula's certificate model. +There's a good chance that if you "switch" from one Router to another, you'll lose your session on your Printer-like host. Such is life under NAT! + +Can the Router be exploited somehow? +* an attacker that shares a network with Printer would be able to spoof traffic as if they are Printer. This is the same risk as UnsafeNetworks today. +* an attacker on the overlay would have their traffic evaluated as inbound + * if they try to tx on the same source IP as Phone, SNAT will assign a different port + * if they try to send inbound traffic that matches the un-masqueraded traffic iptables would have delivered + * conntrack will accept the packet, but before we finish firewalling and return, is the applySnat step + * this will fail because the hostinfo that sent the packet does not contain the vpnip that is associated with the snat entry \ No newline at end of file diff --git a/cert/cert_v2.go b/cert/cert_v2.go index 4648c496..87d1ec11 100644 --- a/cert/cert_v2.go +++ b/cert/cert_v2.go @@ -441,7 +441,7 @@ func (c *certificateV2) validate() error { } } else if network.Addr().Is4() { if !hasV4Networks { - return NewErrInvalidCertificateProperties("IPv4 unsafe networks require an IPv4 address assignment: %s", network) + //return NewErrInvalidCertificateProperties("IPv4 unsafe networks require an IPv4 address assignment: %s", network) } } } diff --git a/firewall.go b/firewall.go index 2d67acbb..2d8cf49e 100644 --- a/firewall.go +++ b/firewall.go @@ -2,6 +2,7 @@ package nebula import ( "crypto/sha256" + "encoding/binary" "encoding/hex" "errors" "fmt" @@ -22,10 +23,29 @@ import ( "github.com/slackhq/nebula/firewall" ) +var ErrCannotSNAT = errors.New("cannot snat this packet") +var ErrSNATIdentityMismatch = errors.New("refusing to SNAT for mismatched host") + type FirewallInterface interface { AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr string, caName string, caSha string) error } +type snatInfo struct { + //Src is the source IP+port to write into unsafe-route-bound packet + Src netip.AddrPort + //SrcVpnIp is the overlay IP associated with this flow. It's needed to associate reply traffic so we can get it back to the right host. + SrcVpnIp netip.Addr + //SnatPort is the port to rewrite into an overlay-bound packet + SnatPort uint16 +} + +func (s *snatInfo) Valid() bool { + if s == nil { + return false + } + return s.Src.IsValid() +} + type conn struct { Expires time.Time // Time when this conntrack entry will expire @@ -34,6 +54,9 @@ type conn struct { // fields pack for free after the uint32 above incoming bool rulesVersion uint16 + + //for SNAT support + snat *snatInfo } // TODO: need conntrack max tracked connections handling @@ -66,6 +89,7 @@ type Firewall struct { defaultLocalCIDRAny bool incomingMetrics firewallMetrics outgoingMetrics firewallMetrics + snatAddr netip.Addr l *logrus.Logger } @@ -83,6 +107,15 @@ type FirewallConntrack struct { TimerWheel *TimerWheel[firewall.Packet] } +func (ct *FirewallConntrack) dupeConnUnlocked(fp firewall.Packet, c *conn, timeout time.Duration) { + if _, ok := ct.Conns[fp]; !ok { + ct.TimerWheel.Advance(time.Now()) + ct.TimerWheel.Add(fp, timeout) + } + + ct.Conns[fp] = c +} + // FirewallTable is the entry point for a rule, the evaluation order is: // Proto AND port AND (CA SHA or CA name) AND local CIDR AND (group OR groups OR name OR remote CIDR) type FirewallTable struct { @@ -131,7 +164,7 @@ type firewallLocalCIDR struct { // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. // The certificate provided should be the highest version loaded in memory. -func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall { +func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate, snatAddr netip.Addr) *Firewall { //TODO: error on 0 duration var tmin, tmax time.Duration @@ -149,12 +182,14 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D tmax = defaultTimeout } + hasV4Networks := false routableNetworks := new(bart.Lite) var assignedNetworks []netip.Prefix for _, network := range c.Networks() { nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen()) routableNetworks.Insert(nprefix) assignedNetworks = append(assignedNetworks, network) + hasV4Networks = hasV4Networks || network.Addr().Is4() } hasUnsafeNetworks := false @@ -163,6 +198,10 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D hasUnsafeNetworks = true } + if !hasUnsafeNetworks || hasV4Networks { + snatAddr = netip.Addr{} //disable using the special snat address if it doesn't make sense to use it + } + return &Firewall{ Conntrack: &FirewallConntrack{ Conns: make(map[firewall.Packet]*conn), @@ -176,6 +215,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D routableNetworks: routableNetworks, assignedNetworks: assignedNetworks, hasUnsafeNetworks: hasUnsafeNetworks, + snatAddr: snatAddr, l: l, incomingMetrics: firewallMetrics{ @@ -191,7 +231,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D } } -func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) { +func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C, snatAddr netip.Addr) (*Firewall, error) { certificate := cs.getCertificate(cert.Version2) if certificate == nil { certificate = cs.getCertificate(cert.Version1) @@ -201,14 +241,7 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew panic("No certificate available to reconfigure the firewall") } - fw := NewFirewall( - l, - c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12), - c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3), - c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10), - certificate, - //TODO: max_connections - ) + fw := NewFirewall(l, c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12), c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3), c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10), certificate, snatAddr) fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false) @@ -314,6 +347,11 @@ func (f *Firewall) GetRuleHashes() string { return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10) } +func (f *Firewall) ShouldUnSNAT(fp *firewall.Packet) bool { + // f.snatAddr is only valid if we're a snat-capable router + return f.snatAddr.IsValid() && fp.RemoteAddr == f.snatAddr +} + func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error { var table string if inbound { @@ -414,50 +452,204 @@ var ErrInvalidRemoteIP = errors.New("remote address is not in remote certificate var ErrInvalidLocalIP = errors.New("local address is not in list of handled local addresses") var ErrNoMatchingRule = errors.New("no matching rule in firewall table") +func (f *Firewall) unSnat(data []byte, fp *firewall.Packet) netip.Addr { + c := f.peek(*fp) //unfortunately this needs to lock. Surely there's a better way. + if c == nil { + return netip.Addr{} + } + if !c.snat.Valid() { + return netip.Addr{} + } + oldIP := netip.AddrPortFrom(f.snatAddr, fp.RemotePort) + rewritePacket(data, fp, oldIP, c.snat.Src, 16, 2) + return c.snat.SrcVpnIp +} + +func rewritePacket(data []byte, fp *firewall.Packet, oldIP netip.AddrPort, newIP netip.AddrPort, ipOffset int, portOffset int) { + //change address + copy(data[ipOffset:], newIP.Addr().AsSlice()) + recalcIPv4Checksum(data, oldIP.Addr(), newIP.Addr()) + ipHeaderLen := int(data[0]&0x0F) * 4 + + switch fp.Protocol { + case firewall.ProtoICMP: + binary.BigEndian.PutUint16(data[ipHeaderLen+4:ipHeaderLen+6], newIP.Port()) //we use the ID field as a "port" for ICMP + icmpCode := uint16(data[ipHeaderLen+1]) //todo not snatting on code yet (but Linux would) + recalcICMPv4Checksum(data, icmpCode, icmpCode, oldIP.Port(), newIP.Port()) + case firewall.ProtoUDP: + dstport := ipHeaderLen + portOffset + binary.BigEndian.PutUint16(data[dstport:dstport+2], newIP.Port()) + recalcUDPv4Checksum(data, oldIP, newIP) + case firewall.ProtoTCP: + dstport := ipHeaderLen + portOffset + binary.BigEndian.PutUint16(data[dstport:dstport+2], newIP.Port()) + recalcTCPv4Checksum(data, oldIP, newIP) + } +} + +func (f *Firewall) findUsableSNATPort(fp *firewall.Packet, c *conn) error { + oldPort := fp.RemotePort + conntrack := f.Conntrack + conntrack.Lock() + defer conntrack.Unlock() + for numPortsChecked := 0; numPortsChecked < 0x7ff; numPortsChecked++ { + _, ok := conntrack.Conns[*fp] + if !ok { + //yay, we can use this port + //track the snatted flow with the same expiration as the unsnatted version + conntrack.dupeConnUnlocked(*fp, c, f.packetTimeout(*fp)) + return nil + } + //increment and retry. There's probably better strategies out there + fp.RemotePort++ + if fp.RemotePort < 0x7ff { + fp.RemotePort += 0x7ff // keep it ephemeral for now + } + } + + //if we made it here, we failed + fp.RemotePort = oldPort + return ErrCannotSNAT +} + +func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo *HostInfo) error { + if !f.snatAddr.IsValid() { + return ErrCannotSNAT + } + if c.snat.Valid() { + //old flow: make sure it came from the right place + if !slices.Contains(hostinfo.vpnAddrs, c.snat.SrcVpnIp) { + return ErrSNATIdentityMismatch + } + fp.RemoteAddr = f.snatAddr + fp.RemotePort = c.snat.SnatPort + } else if hostinfo.vpnAddrs[0].Is6() { + //we got a new flow + c.snat = &snatInfo{ + Src: netip.AddrPortFrom(fp.RemoteAddr, fp.RemotePort), + SrcVpnIp: hostinfo.vpnAddrs[0], + } + fp.RemoteAddr = f.snatAddr + //find a new port to use, if needed + err := f.findUsableSNATPort(fp, c) + if err != nil { + c.snat = nil + return err + } + c.snat.SnatPort = fp.RemotePort //may have been updated inside f.findUsableSNATPort + } else { + return ErrCannotSNAT + } + + newIP := netip.AddrPortFrom(f.snatAddr, c.snat.SnatPort) + rewritePacket(data, fp, c.snat.Src, newIP, 12, 0) + return nil +} + +func (f *Firewall) identifyNetworkType(h *HostInfo, fp firewall.Packet) NetworkType { + if h.networks == nil { + // Simple case: Certificate has one address and no unsafe networks + if h.vpnAddrs[0] == fp.RemoteAddr { + return NetworkTypeVPN + } else if fp.IsIPv4() && h.HasOnlyV6Addresses() { + return NetworkTypeUncheckedSNATPeer + } else { + return NetworkTypeInvalidPeer + } + } else if nwType, ok := h.networks.Lookup(fp.RemoteAddr); ok { + //todo check for if fp.RemoteAddr is our f.snatAddr here too? Does that need a special case? + return nwType //will return NetworkTypeVPN or NetworkTypeUnsafe + } else if fp.IsIPv4() && h.HasOnlyV6Addresses() { //todo surely I'm smart enough to avoid writing these branches twice + return NetworkTypeUncheckedSNATPeer + } else { + return NetworkTypeInvalidPeer + } +} + +func (f *Firewall) allowNetworkType(nwType NetworkType) error { + switch nwType { + case NetworkTypeVPN: + return nil + case NetworkTypeInvalidPeer: + return ErrInvalidRemoteIP + case NetworkTypeVPNPeer: + //todo we might need a specialSnatMode case in here to handle routers with v4 addresses when we don't also have a v4 address? + return ErrPeerRejected // reject for now, one day this may have different FW rules + case NetworkTypeUnsafe: + return nil // nothing special, one day this may have different FW rules + case NetworkTypeUncheckedSNATPeer: + if f.snatAddr.IsValid() { + return nil //todo is this enough? + } else { + return ErrInvalidRemoteIP + } + default: + return ErrUnknownNetworkType //should never happen + } +} + +func (f *Firewall) willingToHandleLocalAddr(incoming bool, fp firewall.Packet, remoteNwType NetworkType) error { + if f.routableNetworks.Contains(fp.LocalAddr) { + return nil //easy, this should handle NetworkTypeVPN in all cases, and NetworkTypeUnsafe on the router side + } + + //watch out, when incoming, this function decides if we will deliver a packet locally + //when outgoing, much less important, it just decides if we're willing to tx + switch remoteNwType { + // we never want to accept unconntracked inbound traffic from these network types, but outbound is okay. + // It's the recipient's job to validate and accept or deny the packet. + case NetworkTypeUncheckedSNATPeer, NetworkTypeUnsafe: + //NetworkTypeUnsafe needed here to allow inbound from an unsafe-router + if incoming { + return ErrInvalidLocalIP + } + return nil + default: + return ErrInvalidLocalIP + } +} + // Drop returns an error if the packet should be dropped, explaining why. It // returns nil if the packet should not be dropped. -func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error { +func (f *Firewall) Drop(fp firewall.Packet, pkt []byte, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error { + table := f.OutRules + if incoming { + table = f.InRules + } + + snatmode := fp.IsIPv4() && h.HasOnlyV6Addresses() && f.snatAddr.IsValid() + if snatmode { + //if this is an IPv4 packet from a V6 only host, and we're configured to snat that kind of traffic, it must be snatted, + //so it can never be in the localcache, which lacks SNAT data + //nil out the pointer to avoid ever using it + localCache = nil + } + // Check if we spoke to this tuple, if we did then allow this packet - if f.inConns(fp, h, caPool, localCache) { + if localCache != nil { + if _, ok := localCache[fp]; ok { + return nil //packet matched the cache, we're not snatting, we can return early + } + } + c := f.inConns(fp, h, caPool, localCache) + if c != nil { + if incoming && snatmode { + return f.applySnat(pkt, &fp, c, h) + } return nil } // Make sure remote address matches nebula certificate, and determine how to treat it - if h.networks == nil { - // Simple case: Certificate has one address and no unsafe networks - if h.vpnAddrs[0] != fp.RemoteAddr { - f.metrics(incoming).droppedRemoteAddr.Inc(1) - return ErrInvalidRemoteIP - } - } else { - nwType, ok := h.networks.Lookup(fp.RemoteAddr) - if !ok { - f.metrics(incoming).droppedRemoteAddr.Inc(1) - return ErrInvalidRemoteIP - } - switch nwType { - case NetworkTypeVPN: - break // nothing special - case NetworkTypeVPNPeer: - f.metrics(incoming).droppedRemoteAddr.Inc(1) - return ErrPeerRejected // reject for now, one day this may have different FW rules - case NetworkTypeUnsafe: - break // nothing special, one day this may have different FW rules - default: - f.metrics(incoming).droppedRemoteAddr.Inc(1) - return ErrUnknownNetworkType //should never happen - } + remoteNetworkType := f.identifyNetworkType(h, fp) + if err := f.allowNetworkType(remoteNetworkType); err != nil { + f.metrics(incoming).droppedRemoteAddr.Inc(1) + return err } // Make sure we are supposed to be handling this local ip address - if !f.routableNetworks.Contains(fp.LocalAddr) { + if err := f.willingToHandleLocalAddr(incoming, fp, remoteNetworkType); err != nil { f.metrics(incoming).droppedLocalAddr.Inc(1) - return ErrInvalidLocalIP - } - - table := f.OutRules - if incoming { - table = f.InRules + return err } // We now know which firewall table to check against @@ -467,9 +659,14 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * } // We always want to conntrack since it is a faster operation - f.addConn(fp, incoming) + c = f.addConn(fp, incoming) - return nil + if incoming && remoteNetworkType == NetworkTypeUncheckedSNATPeer { + return f.applySnat(pkt, &fp, c, h) + } else { + //outgoing snat is handled before this function is called + return nil + } } func (f *Firewall) metrics(incoming bool) firewallMetrics { @@ -496,12 +693,23 @@ func (f *Firewall) EmitStats() { metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV())) } -func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool { - if localCache != nil { - if _, ok := localCache[fp]; ok { - return true - } +func (f *Firewall) peek(fp firewall.Packet) *conn { + conntrack := f.Conntrack + conntrack.Lock() + + // Purge every time we test + ep, has := conntrack.TimerWheel.Purge() + if has { + f.evict(ep) } + + c := conntrack.Conns[fp] + + conntrack.Unlock() + return c +} + +func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) *conn { conntrack := f.Conntrack conntrack.Lock() @@ -515,7 +723,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, if !ok { conntrack.Unlock() - return false + return nil } if c.rulesVersion != f.rulesVersion { @@ -538,7 +746,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, } delete(conntrack.Conns, fp) conntrack.Unlock() - return false + return nil } if f.l.Level >= logrus.DebugLevel { @@ -568,12 +776,11 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache[fp] = struct{}{} } - return true + return c } -func (f *Firewall) addConn(fp firewall.Packet, incoming bool) { +func (f *Firewall) packetTimeout(fp firewall.Packet) time.Duration { var timeout time.Duration - c := &conn{} switch fp.Protocol { case firewall.ProtoTCP: @@ -583,7 +790,13 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) { default: timeout = f.DefaultTimeout } + return timeout +} +func (f *Firewall) addConn(fp firewall.Packet, incoming bool) *conn { + c := &conn{} + + timeout := f.packetTimeout(fp) conntrack := f.Conntrack conntrack.Lock() if _, ok := conntrack.Conns[fp]; !ok { @@ -597,7 +810,9 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) { c.rulesVersion = f.rulesVersion c.Expires = time.Now().Add(timeout) conntrack.Conns[fp] = c + conntrack.Unlock() + return c } // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel @@ -682,6 +897,13 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCer var port int32 + if p.Protocol == firewall.ProtoICMP { + // port numbers are re-used for connection tracking and SNAT, + // but we don't want to actually filter on them for ICMP + // ICMP6 is omitted because we don't attempt to parse code/identifier/etc out of ICMP6 + return fp[firewall.PortAny].match(p, c, caPool) + } + if p.Fragment { port = firewall.PortFragment } else if incoming { diff --git a/firewall/packet.go b/firewall/packet.go index 2cbfb5ea..fac3baab 100644 --- a/firewall/packet.go +++ b/firewall/packet.go @@ -31,6 +31,10 @@ type Packet struct { Fragment bool } +func (fp *Packet) IsIPv4() bool { + return fp.LocalAddr.Is4() && fp.RemoteAddr.Is4() +} + func (fp *Packet) Copy() *Packet { return &Packet{ LocalAddr: fp.LocalAddr, diff --git a/firewall_test.go b/firewall_test.go index a2133760..282929f2 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -21,7 +21,7 @@ import ( func TestNewFirewall(t *testing.T) { l := test.NewLogger() c := &dummyCert{} - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) conntrack := fw.Conntrack assert.NotNil(t, conntrack) assert.NotNil(t, conntrack.Conns) @@ -36,23 +36,23 @@ func TestNewFirewall(t *testing.T) { assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) - fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c) + fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c, netip.Addr{}) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) - fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c) + fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c, netip.Addr{}) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) - fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c) + fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c, netip.Addr{}) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) - fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c) + fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c, netip.Addr{}) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) - fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c) + fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c, netip.Addr{}) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) } @@ -63,7 +63,7 @@ func TestFirewall_AddRule(t *testing.T) { l.SetOutput(ob) c := &dummyCert{} - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) assert.NotNil(t, fw.InRules) assert.NotNil(t, fw.OutRules) @@ -79,56 +79,56 @@ func TestFirewall_AddRule(t *testing.T) { assert.Empty(t, fw.InRules.TCP[1].Any.Groups) assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "")) assert.Nil(t, fw.InRules.UDP[1].Any.Any) assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1") assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", "")) //no matter what port is given for icmp, it should end up as "any" assert.Nil(t, fw.InRules.ICMP[firewall.PortAny].Any.Any) assert.Empty(t, fw.InRules.ICMP[firewall.PortAny].Any.Groups) assert.Contains(t, fw.InRules.ICMP[firewall.PortAny].Any.Hosts, "h1") - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", "")) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti) assert.True(t, ok) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6.String(), "", "", "")) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) _, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6) assert.True(t, ok) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti.String(), "", "")) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti) assert.True(t, ok) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti6.String(), "", "")) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6) assert.True(t, ok) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "ca-name", "")) assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "ca-sha")) assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", "", "", "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) anyIp, err := netip.ParsePrefix("0.0.0.0/0") require.NoError(t, err) @@ -139,7 +139,7 @@ func TestFirewall_AddRule(t *testing.T) { table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9")) assert.False(t, ok) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) anyIp6, err := netip.ParsePrefix("::/0") require.NoError(t, err) @@ -150,28 +150,28 @@ func TestFirewall_AddRule(t *testing.T) { table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1")) assert.False(t, ok) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "any", "", "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp.String(), "", "")) assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1"))) assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9"))) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp6.String(), "", "")) assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9"))) assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1"))) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", "any", "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) // Test error conditions - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", "", "", "", "")) require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", "", "", "", "")) } @@ -208,49 +208,49 @@ func TestFirewall_Drop(t *testing.T) { } h.buildNetworks(myVpnNetworksTable, &c) - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) cp := cert.NewCAPool() // Drop outbound - assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil)) + assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, nil, false, &h, cp, nil)) // Allow inbound resetConntrack(fw) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) // Allow outbound because conntrack - require.NoError(t, fw.Drop(p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil)) // test remote mismatch oldRemote := p.RemoteAddr p.RemoteAddr = netip.MustParseAddr("1.2.3.10") - assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP) + assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrInvalidRemoteIP) p.RemoteAddr = oldRemote // ensure signer doesn't get in the way of group checks - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) - assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on match - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) - assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) } func TestFirewall_DropV6(t *testing.T) { @@ -287,49 +287,49 @@ func TestFirewall_DropV6(t *testing.T) { } h.buildNetworks(myVpnNetworksTable, &c) - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) cp := cert.NewCAPool() // Drop outbound - assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil)) + assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, nil, false, &h, cp, nil)) // Allow inbound resetConntrack(fw) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) // Allow outbound because conntrack - require.NoError(t, fw.Drop(p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil)) // test remote mismatch oldRemote := p.RemoteAddr p.RemoteAddr = netip.MustParseAddr("fd12::56") - assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP) + assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrInvalidRemoteIP) p.RemoteAddr = oldRemote // ensure signer doesn't get in the way of group checks - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) - assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on match - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) - assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) } func BenchmarkFirewallTable_match(b *testing.B) { @@ -532,15 +532,15 @@ func TestFirewall_Drop2(t *testing.T) { } h1.buildNetworks(myVpnNetworksTable, c1.Certificate) - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", "", "", "", "")) cp := cert.NewCAPool() // h1/c1 lacks the proper groups - require.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule) + require.ErrorIs(t, fw.Drop(p, nil, true, &h1, cp, nil), ErrNoMatchingRule) // c has the proper groups resetConntrack(fw) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) } func TestFirewall_Drop3(t *testing.T) { @@ -612,24 +612,24 @@ func TestFirewall_Drop3(t *testing.T) { } h3.buildNetworks(myVpnNetworksTable, c3.Certificate) - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "signer-sha")) cp := cert.NewCAPool() // c1 should pass because host match - require.NoError(t, fw.Drop(p, true, &h1, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h1, cp, nil)) // c2 should pass because ca sha match resetConntrack(fw) - require.NoError(t, fw.Drop(p, true, &h2, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h2, cp, nil)) // c3 should fail because no match resetConntrack(fw) - assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p, nil, true, &h3, cp, nil), ErrNoMatchingRule) // Test a remote address match - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", "")) - require.NoError(t, fw.Drop(p, true, &h1, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h1, cp, nil)) } func TestFirewall_Drop3V6(t *testing.T) { @@ -664,10 +664,10 @@ func TestFirewall_Drop3V6(t *testing.T) { h.buildNetworks(myVpnNetworksTable, c.Certificate) // Test a remote address match - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) cp := cert.NewCAPool() require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", "")) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) } func TestFirewall_DropConntrackReload(t *testing.T) { @@ -704,35 +704,35 @@ func TestFirewall_DropConntrackReload(t *testing.T) { } h.buildNetworks(myVpnNetworksTable, c.Certificate) - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) cp := cert.NewCAPool() // Drop outbound - assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) // Allow outbound because conntrack - require.NoError(t, fw.Drop(p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil)) oldFw := fw - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", "", "", "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 // Allow outbound because conntrack and new rules allow port 10 - require.NoError(t, fw.Drop(p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil)) oldFw = fw - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", "", "", "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 // Drop outbound because conntrack doesn't match new ruleset - assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrNoMatchingRule) } func TestFirewall_ICMPPortBehavior(t *testing.T) { @@ -771,19 +771,19 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) { } t.Run("ICMP allowed", func(t *testing.T) { - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 0, 0, []string{"any"}, "", "", "", "", "")) t.Run("zero ports", func(t *testing.T) { p := templ.Copy() p.LocalPort = 0 p.RemotePort = 0 // Drop outbound - assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - require.NoError(t, fw.Drop(*p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil)) //now also allow outbound - require.NoError(t, fw.Drop(*p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil)) }) t.Run("nonzero ports", func(t *testing.T) { @@ -791,29 +791,29 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) { p.LocalPort = 0xabcd p.RemotePort = 0x1234 // Drop outbound - assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - require.NoError(t, fw.Drop(*p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil)) //now also allow outbound - require.NoError(t, fw.Drop(*p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil)) }) }) t.Run("Any proto, some ports allowed", func(t *testing.T) { - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 80, 444, []string{"any"}, "", "", "", "", "")) t.Run("zero ports, still blocked", func(t *testing.T) { p := templ.Copy() p.LocalPort = 0 p.RemotePort = 0 // Drop outbound - assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule) //now also allow outbound - assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) }) t.Run("nonzero ports, still blocked", func(t *testing.T) { @@ -821,12 +821,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) { p.LocalPort = 0xabcd p.RemotePort = 0x1234 // Drop outbound - assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule) //now also allow outbound - assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) }) t.Run("nonzero, matching ports, still blocked", func(t *testing.T) { @@ -834,16 +834,16 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) { p.LocalPort = 80 p.RemotePort = 80 // Drop outbound - assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule) //now also allow outbound - assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) }) }) t.Run("Any proto, any port", func(t *testing.T) { - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) t.Run("zero ports, allowed", func(t *testing.T) { resetConntrack(fw) @@ -851,12 +851,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) { p.LocalPort = 0 p.RemotePort = 0 // Drop outbound - assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - require.NoError(t, fw.Drop(*p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil)) //now also allow outbound - require.NoError(t, fw.Drop(*p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil)) }) t.Run("nonzero ports, allowed", func(t *testing.T) { @@ -865,15 +865,15 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) { p.LocalPort = 0xabcd p.RemotePort = 0x1234 // Drop outbound - assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - require.NoError(t, fw.Drop(*p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil)) //now also allow outbound - require.NoError(t, fw.Drop(*p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil)) //different ID is blocked p.RemotePort++ - require.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + require.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) }) }) @@ -908,7 +908,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) { } h1.buildNetworks(myVpnNetworksTable, c1.Certificate) - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "")) cp := cert.NewCAPool() @@ -922,7 +922,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) { Protocol: firewall.ProtoUDP, Fragment: false, } - assert.Equal(t, fw.Drop(p, true, &h1, cp, nil), ErrInvalidRemoteIP) + assert.Equal(t, fw.Drop(p, nil, true, &h1, cp, nil), ErrInvalidRemoteIP) } func BenchmarkLookup(b *testing.B) { @@ -1047,53 +1047,53 @@ func TestNewFirewallFromConfig(t *testing.T) { conf := config.NewC(l) conf.Settings["firewall"] = map[string]any{"outbound": "asdf"} - _, err = NewFirewallFromConfig(l, cs, conf) + _, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{}) require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules") // Test both port and code conf = config.NewC(l) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}} - _, err = NewFirewallFromConfig(l, cs, conf) + _, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{}) require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided") // Test missing host, group, cidr, ca_name and ca_sha conf = config.NewC(l) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}} - _, err = NewFirewallFromConfig(l, cs, conf) + _, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{}) require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided") // Test code/port error conf = config.NewC(l) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh", "proto": "any"}}} - _, err = NewFirewallFromConfig(l, cs, conf) + _, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{}) require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh", "proto": "any"}}} - _, err = NewFirewallFromConfig(l, cs, conf) + _, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{}) require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") // Test proto error conf = config.NewC(l) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}} - _, err = NewFirewallFromConfig(l, cs, conf) + _, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{}) require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``") // Test cidr parse error conf = config.NewC(l) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}} - _, err = NewFirewallFromConfig(l, cs, conf) + _, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{}) require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test local_cidr parse error conf = config.NewC(l) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}} - _, err = NewFirewallFromConfig(l, cs, conf) + _, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{}) require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test both group and groups conf = config.NewC(l) conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}} - _, err = NewFirewallFromConfig(l, cs, conf) + _, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{}) require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided") } @@ -1336,7 +1336,7 @@ func (c *testcase) Test(t *testing.T, fw *Firewall) { t.Helper() cp := cert.NewCAPool() resetConntrack(fw) - err := fw.Drop(c.p, true, c.h, cp, nil) + err := fw.Drop(c.p, nil, true, c.h, cp, nil) if c.err == nil { require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr) } else { @@ -1344,7 +1344,7 @@ func (c *testcase) Test(t *testing.T, fw *Firewall) { } } -func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) testcase { +func buildHostinfo(setup testsetup, theirPrefixes ...netip.Prefix) *HostInfo { c1 := dummyCert{ name: "host1", networks: theirPrefixes, @@ -1364,6 +1364,11 @@ func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) te h.vpnAddrs[i] = theirPrefixes[i].Addr() } h.buildNetworks(setup.myVpnNetworksTable, &c1) + return &h +} + +func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) testcase { + h := buildHostinfo(setup, theirPrefixes...) p := firewall.Packet{ LocalAddr: setup.c.Networks()[0].Addr(), //todo? RemoteAddr: theirPrefixes[0].Addr(), @@ -1373,9 +1378,9 @@ func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) te Fragment: false, } return testcase{ - h: &h, + h: h, p: p, - c: &c1, + c: h.ConnectionState.peerCert.Certificate, err: err, } } @@ -1397,12 +1402,25 @@ func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testse return newSetupFromCert(t, l, c) } +func newSnatSetup(t *testing.T, l *logrus.Logger, myPrefix netip.Prefix, snatAddr netip.Addr) testsetup { + c := dummyCert{ + name: "me", + networks: []netip.Prefix{myPrefix}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + out := newSetupFromCert(t, l, c) + out.fw.snatAddr = snatAddr + return out +} + func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup { myVpnNetworksTable := new(bart.Lite) for _, prefix := range c.Networks() { myVpnNetworksTable.Insert(prefix) } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) return testsetup{ @@ -1532,3 +1550,59 @@ func resetConntrack(fw *Firewall) { fw.Conntrack.Conns = map[firewall.Packet]*conn{} fw.Conntrack.Unlock() } + +func TestFirewall_SNAT(t *testing.T) { + t.Parallel() + l := test.NewLogger() + ob := &bytes.Buffer{} + l.SetOutput(ob) + cp := cert.NewCAPool() + myPrefix := netip.MustParsePrefix("1.1.1.1/8") + + MyCert := dummyCert{ + name: "me", + networks: []netip.Prefix{myPrefix}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + theirPrefix := netip.MustParsePrefix("1.2.2.2/8") + snatAddr := netip.MustParseAddr("169.254.55.96") + t.Run("allow inbound all matching", func(t *testing.T) { + t.Parallel() + myCert := MyCert.Copy() + setup := newSnatSetup(t, l, myPrefix, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert, snatAddr) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) + resetConntrack(setup.fw) + h := buildHostinfo(setup, theirPrefix) + p := firewall.Packet{ + LocalAddr: setup.c.Networks()[0].Addr(), //todo? + RemoteAddr: h.vpnAddrs[0], + LocalPort: 10, + RemotePort: 90, + Protocol: firewall.ProtoUDP, + Fragment: false, + } + require.NoError(t, setup.fw.Drop(p, nil, true, h, cp, nil)) + }) + //t.Run("allow inbound unsafe route", func(t *testing.T) { + // t.Parallel() + // unsafePrefix := netip.MustParsePrefix("192.168.0.0/24") + // c := dummyCert{ + // name: "me", + // networks: []netip.Prefix{myPrefix}, + // unsafeNetworks: []netip.Prefix{unsafePrefix}, + // groups: []string{"default-group"}, + // issuer: "signer-shasum", + // } + // unsafeSetup := newSetupFromCert(t, l, c) + // tc := buildTestCase(unsafeSetup, nil, twoPrefixes...) + // tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3") + // tc.err = ErrNoMatchingRule + // tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off + // require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", unsafePrefix.String(), "", "")) + // tc.err = nil + // tc.Test(t, unsafeSetup.fw) //should pass + //}) +} diff --git a/hostmap.go b/hostmap.go index 7e2939e0..ff5ee456 100644 --- a/hostmap.go +++ b/hostmap.go @@ -224,6 +224,9 @@ const ( NetworkTypeVPNPeer // NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks() NetworkTypeUnsafe + // NetworkTypeUncheckedSNATPeer is used to indicate traffic we're willing to route, but never deliver to a NetworkTypeVPN + NetworkTypeUncheckedSNATPeer + NetworkTypeInvalidPeer ) type HostInfo struct { @@ -277,6 +280,15 @@ type HostInfo struct { lastUsed time.Time } +func (i *HostInfo) HasOnlyV6Addresses() bool { + for _, vpnIp := range i.vpnAddrs { + if !vpnIp.Is6() { + return false + } + } + return true +} + type ViaSender struct { UdpAddr netip.AddrPort relayHI *HostInfo // relayHI is the host info object of the relay diff --git a/inside.go b/inside.go index 0d53f952..3f7cd19e 100644 --- a/inside.go +++ b/inside.go @@ -48,9 +48,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet return } - hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) { - hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) - }) + hostinfo, ready := f.getHostinfo(packet, fwPacket) if hostinfo == nil { f.rejectInside(packet, out, q) @@ -66,10 +64,9 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet return } - dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) + dropReason := f.firewall.Drop(*fwPacket, packet, false, hostinfo, f.pki.GetCAPool(), localCache) if dropReason == nil { f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q) - } else { f.rejectInside(packet, out, q) if f.l.Level >= logrus.DebugLevel { @@ -81,6 +78,26 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } } +func (f *Interface) getHostinfo(packet []byte, fwPacket *firewall.Packet) (*HostInfo, bool) { + if f.firewall.ShouldUnSNAT(fwPacket) { + //unsnat packet re-writing also happens here, would be nice to not, + //but we need to do the unsnat lookup to find the hostinfo so we can run the firewall checks + destVpnAddr := f.firewall.unSnat(packet, fwPacket) + if destVpnAddr.IsValid() { + //because this was a snatted packet, we know it has an on-overlay destination, so no routing should be required. + return f.getOrHandshakeNoRouting(destVpnAddr, func(hh *HandshakeHostInfo) { + hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) + }) + } else { + return nil, false + } + } else { //if we didn't need to unsnat + return f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) { + hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) + }) + } +} + func (f *Interface) rejectInside(packet []byte, out []byte, q int) { if !f.firewall.InSendReject { return @@ -218,7 +235,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp } // check if packet is in outbound fw rules - dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil) + dropReason := f.firewall.Drop(*fp, nil, false, hostinfo, f.pki.GetCAPool(), nil) if dropReason != nil { if f.l.Level >= logrus.DebugLevel { f.l.WithField("fwPacket", fp). diff --git a/interface.go b/interface.go index 61b1f228..1377bc1b 100644 --- a/interface.go +++ b/interface.go @@ -56,6 +56,7 @@ type Interface struct { inside overlay.Device pki *PKI firewall *Firewall + snatAddr netip.Addr connectionManager *connectionManager handshakeManager *HandshakeManager serveDns bool @@ -339,7 +340,7 @@ func (f *Interface) reloadFirewall(c *config.C) { return } - fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c) + fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c, f.firewall.snatAddr) if err != nil { f.l.WithError(err).Error("Error while creating firewall during reload") return diff --git a/main.go b/main.go index 74979417..e924ce8a 100644 --- a/main.go +++ b/main.go @@ -66,7 +66,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err) } - fw, err := NewFirewallFromConfig(l, pki.getCertState(), c) + snatAddr := netip.MustParseAddr("169.254.55.96") //todo get this from tun! + fw, err := NewFirewallFromConfig(l, pki.getCertState(), c, snatAddr) if err != nil { return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err) } @@ -131,7 +132,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg deviceFactory = overlay.NewDeviceFromConfig } - tun, err = deviceFactory(c, l, pki.getCertState().myVpnNetworks, routines) + cs := pki.getCertState() + tun, err = deviceFactory(c, l, cs.myVpnNetworks, cs.GetDefaultCertificate().UnsafeNetworks(), routines) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err) } diff --git a/outside.go b/outside.go index b2cbf123..e78b1cb7 100644 --- a/outside.go +++ b/outside.go @@ -514,7 +514,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out return false } - dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache) + dropReason := f.firewall.Drop(*fwPacket, out, true, hostinfo, f.pki.GetCAPool(), localCache) if dropReason != nil { // NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore // This gives us a buffer to build the reject packet in diff --git a/overlay/tun.go b/overlay/tun.go index e0bf69f6..56c7fccf 100644 --- a/overlay/tun.go +++ b/overlay/tun.go @@ -22,22 +22,22 @@ func (e *NameError) Error() string { } // TODO: We may be able to remove routines -type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) +type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error) -func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { +func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error) { switch { case c.GetBool("tun.disabled", false): tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) return tun, nil default: - return newTun(c, l, vpnNetworks, routines > 1) + return newTun(c, l, vpnNetworks, unsafeNetworks, routines > 1) } } func NewFdDeviceFromConfig(fd *int) DeviceFactory { - return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { - return newTunFromFd(c, l, *fd, vpnNetworks) + return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error) { + return newTunFromFd(c, l, *fd, vpnNetworks, unsafeNetworks) } } diff --git a/overlay/tun_android.go b/overlay/tun_android.go index eddef882..f091772a 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -26,7 +26,7 @@ type tun struct { l *logrus.Logger } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix, _ []netip.Prefix) (*tun, error) { // XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly. // Be sure not to call file.Fd() as it will set the fd to blocking mode. file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") @@ -53,7 +53,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net return t, nil } -func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ []netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in Android") } diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 128c2001..c9c3927e 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -79,7 +79,7 @@ type ifreqAlias6 struct { Lifetime addrLifetime } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) { name := c.GetString("tun.dev", "") ifIndex := -1 if name != "" && name != "utun" { @@ -153,7 +153,7 @@ func (t *tun) deviceBytes() (o [16]byte) { return } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in Darwin") } diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 2f65b3a4..c66e45f1 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -199,11 +199,11 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open existing tun device var fd int var err error diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index 0ce01df8..85466d1e 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -28,11 +28,11 @@ type tun struct { l *logrus.Logger } -func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ []netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in iOS") } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix, _ []netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/tun") t := &tun{ vpnNetworks: vpnNetworks, diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 7e4aa418..a0a70870 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -26,14 +26,15 @@ import ( type tun struct { io.ReadWriteCloser - fd int - Device string - vpnNetworks []netip.Prefix - MaxMTU int - DefaultMTU int - TXQueueLen int - deviceIndex int - ioctlFd uintptr + fd int + Device string + vpnNetworks []netip.Prefix + unsafeNetworks []netip.Prefix + MaxMTU int + DefaultMTU int + TXQueueLen int + deviceIndex int + ioctlFd uintptr Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] @@ -71,10 +72,10 @@ type ifreqQLEN struct { pad [8]byte } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") - t, err := newTunGeneric(c, l, file, vpnNetworks) + t, err := newTunGeneric(c, l, file, vpnNetworks, unsafeNetworks) if err != nil { return nil, err } @@ -84,7 +85,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net return t, nil } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, multiqueue bool) (*tun, error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { // If /dev/net/tun doesn't exist, try to create it (will happen in docker) @@ -123,7 +124,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu name := strings.Trim(string(req.Name[:]), "\x00") file := os.NewFile(uintptr(fd), "/dev/net/tun") - t, err := newTunGeneric(c, l, file, vpnNetworks) + t, err := newTunGeneric(c, l, file, vpnNetworks, unsafeNetworks) if err != nil { return nil, err } @@ -133,11 +134,12 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu return t, nil } -func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) { +func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix) (*tun, error) { t := &tun{ ReadWriteCloser: file, fd: int(file.Fd()), vpnNetworks: vpnNetworks, + unsafeNetworks: unsafeNetworks, TXQueueLen: c.GetInt("tun.tx_queue", 500), useSystemRoutes: c.GetBool("tun.use_system_route_table", false), useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0), @@ -427,6 +429,27 @@ func (t *tun) setMTU() { } } +func (t *tun) setSnatRoute() error { + snataddr := netip.MustParsePrefix("169.254.55.96/32") //todo get this from elsewhere? Or maybe we should pick it, and feed it back out to the firewall? + dr := &net.IPNet{ + IP: snataddr.Masked().Addr().AsSlice(), + Mask: net.CIDRMask(snataddr.Bits(), snataddr.Addr().BitLen()), + } + + nr := netlink.Route{ + LinkIndex: t.deviceIndex, + Dst: dr, + //todo do we need these other options? + //MTU: t.DefaultMTU, + //AdvMSS: t.advMSS(Route{}), + Scope: unix.RT_SCOPE_LINK, + //Protocol: unix.RTPROT_KERNEL, + Table: unix.RT_TABLE_MAIN, + Type: unix.RTN_UNICAST, + } + return netlink.RouteReplace(&nr) +} + func (t *tun) setDefaultRoute(cidr netip.Prefix) error { dr := &net.IPNet{ IP: cidr.Masked().Addr().AsSlice(), @@ -503,6 +526,18 @@ func (t *tun) addRoutes(logErrors bool) error { } } + onlyV6Addresses := false + for _, n := range t.vpnNetworks { + if n.Addr().Is6() { + onlyV6Addresses = true + break + } + } + + if len(t.unsafeNetworks) != 0 && onlyV6Addresses { + return t.setSnatRoute() + } + return nil } diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index 2986c895..39336108 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -70,11 +70,11 @@ type tun struct { var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in NetBSD") } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open tun device var err error deviceName := c.GetString("tun.dev", "") diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index 9209b795..701d97dd 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -63,11 +63,11 @@ type tun struct { var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in openbsd") } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open tun device var err error deviceName := c.GetString("tun.dev", "") diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index 3477de3d..145eccb9 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -28,7 +28,7 @@ type TestTun struct { TxPackets chan []byte // Packets transmitted outside by nebula } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*TestTun, error) { _, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true) if err != nil { return nil, err @@ -49,7 +49,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( }, nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*TestTun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*TestTun, error) { return nil, fmt.Errorf("newTunFromFd not supported") } diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index 223eabee..d434d66d 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -38,11 +38,11 @@ type winTun struct { tun *wintun.NativeTun } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (Device, error) { return nil, fmt.Errorf("newTunFromFd not supported in Windows") } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*winTun, error) { err := checkWinTunExists() if err != nil { return nil, fmt.Errorf("can not load the wintun driver: %w", err) diff --git a/overlay/user.go b/overlay/user.go index 1f92d4e9..52fa0df7 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -9,7 +9,7 @@ import ( "github.com/slackhq/nebula/routing" ) -func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { +func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error) { return NewUserDevice(vpnNetworks) } diff --git a/snat.go b/snat.go new file mode 100644 index 00000000..3164b641 --- /dev/null +++ b/snat.go @@ -0,0 +1,91 @@ +package nebula + +import ( + "encoding/binary" + "net/netip" +) + +func recalcIPv4Checksum(data []byte, oldSrcIP netip.Addr, newSrcIP netip.Addr) { + oldChecksum := binary.BigEndian.Uint16(data[10:12]) + //because of how checksums work, we can re-use this function + checksum := calcNewTransportChecksum(oldChecksum, oldSrcIP, 0, newSrcIP, 0) + binary.BigEndian.PutUint16(data[10:12], checksum) +} + +func calcNewTransportChecksum(oldChecksum uint16, oldSrcIP netip.Addr, oldSrcPort uint16, newSrcIP netip.Addr, newSrcPort uint16) uint16 { + oldIP := binary.BigEndian.Uint32(oldSrcIP.AsSlice()) + newIP := binary.BigEndian.Uint32(newSrcIP.AsSlice()) + + // Start with inverted checksum + checksum := uint32(^oldChecksum) + + // Subtract old IP (as two 16-bit words) + checksum += uint32(^uint16(oldIP >> 16)) + checksum += uint32(^uint16(oldIP & 0xFFFF)) + + // Subtract old port + checksum += uint32(^oldSrcPort) + + // Add new IP (as two 16-bit words) + checksum += uint32(newIP >> 16) + checksum += uint32(newIP & 0xFFFF) + + // Add new port + checksum += uint32(newSrcPort) + + // Fold carries + for checksum > 0xFFFF { + checksum = (checksum & 0xFFFF) + (checksum >> 16) + } + + // Return ones' complement + return ^uint16(checksum) +} + +func recalcV4TransportChecksum(offsetInsideHeader int, data []byte, oldSrcIP netip.AddrPort, newSrcIP netip.AddrPort) { + ipHeaderOffset := int(data[0]&0x0F) * 4 + offset := ipHeaderOffset + offsetInsideHeader + oldcsum := binary.BigEndian.Uint16(data[offset : offset+2]) + checksum := calcNewTransportChecksum(oldcsum, oldSrcIP.Addr(), oldSrcIP.Port(), newSrcIP.Addr(), newSrcIP.Port()) + binary.BigEndian.PutUint16(data[offset:offset+2], checksum) +} + +func recalcUDPv4Checksum(data []byte, oldSrcIP netip.AddrPort, newSrcIP netip.AddrPort) { + const offsetInsideHeader = 6 + recalcV4TransportChecksum(offsetInsideHeader, data, oldSrcIP, newSrcIP) +} + +func recalcTCPv4Checksum(data []byte, oldSrcIP netip.AddrPort, newSrcIP netip.AddrPort) { + const offsetInsideHeader = 16 + recalcV4TransportChecksum(offsetInsideHeader, data, oldSrcIP, newSrcIP) +} + +func calcNewICMPChecksum(oldChecksum uint16, oldCode uint16, newCode uint16, oldID uint16, newID uint16) uint16 { + // Start with inverted checksum + checksum := uint32(^oldChecksum) + + // Subtract old stuff + checksum += uint32(^oldCode) + checksum += uint32(^oldID) + + // Add new stuff + checksum += uint32(newCode) + checksum += uint32(newID) + + // Fold carries + for checksum > 0xFFFF { + checksum = (checksum & 0xFFFF) + (checksum >> 16) + } + + // Return ones' complement + return ^uint16(checksum) +} + +func recalcICMPv4Checksum(data []byte, oldCode uint16, newCode uint16, oldID uint16, newID uint16) { + const offsetInsideHeader = 2 + ipHeaderOffset := int(data[0]&0x0F) * 4 + offset := ipHeaderOffset + offsetInsideHeader + oldChecksum := binary.BigEndian.Uint16(data[offset : offset+2]) + checksum := calcNewICMPChecksum(oldChecksum, oldCode, newCode, oldID, newID) + binary.BigEndian.PutUint16(data[offset:offset+2], checksum) +} From 83744a106d3877d0454d7a28f860b8b4635dfce7 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Fri, 13 Feb 2026 12:55:32 -0600 Subject: [PATCH 02/31] checkpt --- overlay/tun_linux.go | 94 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 77 insertions(+), 17 deletions(-) diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index a0a70870..4ba815e9 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -4,6 +4,7 @@ package overlay import ( + "crypto/rand" "fmt" "io" "net" @@ -47,6 +48,8 @@ type tun struct { routesFromSystem map[netip.Prefix]routing.Gateways routesFromSystemLock sync.Mutex + snatAddr netip.Prefix + l *logrus.Logger } @@ -162,6 +165,59 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []n return t, nil } +func (t *tun) prepareSnatAddr(c *config.C, initial bool, routes []Route) netip.Prefix { + if !initial { + return netip.Prefix{} //I don't wanna think about reloading this yet + } + if !t.vpnNetworks[0].Addr().Is6() { + return netip.Prefix{} //if we have an IPv4 assignment within the overlay, we don't need a snat address + } + + addSnatAddr := false + for _, un := range t.unsafeNetworks { //if we are an unsafe router for an IPv4 range + if un.Addr().Is4() { + addSnatAddr = true + break + } + } + for _, route := range routes { //or if we have a route defined into an IPv4 range + if route.Cidr.Addr().Is4() { + addSnatAddr = true //todo should this only apply to unsafe routes? + break + } + } + if !addSnatAddr { + return netip.Prefix{} + } + + var err error + out := netip.Addr{} + if a := c.GetString("tun.snat_address_for_4over6", ""); a != "" { + out, err = netip.ParseAddr(a) + if err != nil { + t.l.WithField("value", a).WithError(err).Warn("failed to parse tun.snat_address_for_4over6, will use a random value") + } else if !out.Is4() || !out.IsLinkLocalUnicast() { + t.l.WithField("value", t.snatAddr).Warn("tun.snat_address_for_4over6 must be an IPv4 address") + } + } + if !out.IsValid() { + octets := []byte{169, 254, 0, 0} + _, _ = rand.Read(octets[2:4]) + if octets[3] == 0 { + octets[3] = 1 //please no .0 addresses + } else if octets[2] == 255 && octets[3] == 255 { + octets[3] = 254 //please no broadcast addresses + } + ok := false + out, ok = netip.AddrFromSlice(octets) + if !ok { + t.l.Error("failed to produce a valid IPv4 address for tun.snat_address_for_4over6") + return netip.Prefix{} + } + } + return netip.PrefixFrom(out, 32) +} + func (t *tun) reload(c *config.C, initial bool) error { routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { @@ -172,6 +228,8 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } + t.snatAddr = t.prepareSnatAddr(c, initial, routes) + routeTree, err := makeRouteTree(t.l, routes, true) if err != nil { return err @@ -314,6 +372,16 @@ func (t *tun) addIPs(link netlink.Link) error { Label: t.vpnNetworks[i].Addr().Zone(), } } + if t.snatAddr.IsValid() { + newAddrs = append(newAddrs, &netlink.Addr{ + IPNet: &net.IPNet{ + IP: t.snatAddr.Addr().AsSlice(), + Mask: net.CIDRMask(t.snatAddr.Bits(), t.snatAddr.Addr().BitLen()), + }, + Label: t.snatAddr.Addr().Zone(), + }) + t.l.WithField("address", t.snatAddr).Info("Adding SNAT address") + } //add all new addresses for i := range newAddrs { @@ -402,7 +470,12 @@ func (t *tun) Activate() error { //set route MTU for i := range t.vpnNetworks { if err = t.setDefaultRoute(t.vpnNetworks[i]); err != nil { - return fmt.Errorf("failed to set default route MTU: %w", err) + return fmt.Errorf("failed to set default route MTU for %s: %w", t.vpnNetworks[i], err) + } + } + if t.snatAddr.IsValid() { + if err = t.setDefaultRoute(t.snatAddr); err != nil { + return fmt.Errorf("failed to set default route MTU for %s: %w", t.snatAddr, err) } } @@ -430,10 +503,9 @@ func (t *tun) setMTU() { } func (t *tun) setSnatRoute() error { - snataddr := netip.MustParsePrefix("169.254.55.96/32") //todo get this from elsewhere? Or maybe we should pick it, and feed it back out to the firewall? dr := &net.IPNet{ - IP: snataddr.Masked().Addr().AsSlice(), - Mask: net.CIDRMask(snataddr.Bits(), snataddr.Addr().BitLen()), + IP: t.snatAddr.Masked().Addr().AsSlice(), + Mask: net.CIDRMask(t.snatAddr.Bits(), t.snatAddr.Addr().BitLen()), } nr := netlink.Route{ @@ -526,19 +598,7 @@ func (t *tun) addRoutes(logErrors bool) error { } } - onlyV6Addresses := false - for _, n := range t.vpnNetworks { - if n.Addr().Is6() { - onlyV6Addresses = true - break - } - } - - if len(t.unsafeNetworks) != 0 && onlyV6Addresses { - return t.setSnatRoute() - } - - return nil + return t.setSnatRoute() } func (t *tun) removeRoutes(routes []Route) { From 1cc257f99741bfce5132843c7f660c7be654688a Mon Sep 17 00:00:00 2001 From: JackDoan Date: Tue, 17 Feb 2026 13:09:31 -0600 Subject: [PATCH 03/31] bolt more stuff onto tun to help auto-assign snat addresses --- firewall.go | 18 +++++++++-- interface.go | 4 ++- main.go | 3 +- overlay/device.go | 2 ++ overlay/tun.go | 51 +++++++++++++++++++++++++++++++ overlay/tun_android.go | 18 +++++++---- overlay/tun_darwin.go | 25 +++++++++++----- overlay/tun_disabled.go | 7 +++++ overlay/tun_freebsd.go | 37 +++++++++++++++-------- overlay/tun_ios.go | 20 +++++++++---- overlay/tun_linux.go | 66 +++++++---------------------------------- overlay/tun_netbsd.go | 26 +++++++++++----- overlay/tun_openbsd.go | 39 +++++++++++++++--------- overlay/tun_tester.go | 34 +++++++++++++-------- overlay/tun_windows.go | 35 +++++++++++++++------- overlay/user.go | 8 +++++ test/tun.go | 10 +++++++ 17 files changed, 267 insertions(+), 136 deletions(-) diff --git a/firewall.go b/firewall.go index 2d8cf49e..58dea318 100644 --- a/firewall.go +++ b/firewall.go @@ -215,7 +215,6 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D routableNetworks: routableNetworks, assignedNetworks: assignedNetworks, hasUnsafeNetworks: hasUnsafeNetworks, - snatAddr: snatAddr, l: l, incomingMetrics: firewallMetrics{ @@ -231,7 +230,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D } } -func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C, snatAddr netip.Addr) (*Firewall, error) { +func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) { certificate := cs.getCertificate(cert.Version2) if certificate == nil { certificate = cs.getCertificate(cert.Version1) @@ -241,7 +240,14 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C, snatAdd panic("No certificate available to reconfigure the firewall") } - fw := NewFirewall(l, c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12), c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3), c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10), certificate, snatAddr) + fw := NewFirewall( + l, + c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12), + c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3), + c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10), + certificate, + netip.Addr{}, + ) fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false) @@ -347,6 +353,12 @@ func (f *Firewall) GetRuleHashes() string { return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10) } +func (f *Firewall) SetSNATAddressFromInterface(i *Interface) { + //address-mutation-avoidance is done inside Interface, the firewall doesn't need to care + //todo should snatted conntracks get expired out? Probably not needed until if/when we allow reload + f.snatAddr = i.inside.SNATAddress().Addr() +} + func (f *Firewall) ShouldUnSNAT(fp *firewall.Packet) bool { // f.snatAddr is only valid if we're a snat-capable router return f.snatAddr.IsValid() && fp.RemoteAddr == f.snatAddr diff --git a/interface.go b/interface.go index 1377bc1b..83d313b5 100644 --- a/interface.go +++ b/interface.go @@ -249,6 +249,7 @@ func (f *Interface) activate() { f.inside.Close() f.l.Fatal(err) } + f.firewall.SetSNATAddressFromInterface(f) } func (f *Interface) run() { @@ -340,11 +341,12 @@ func (f *Interface) reloadFirewall(c *config.C) { return } - fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c, f.firewall.snatAddr) + fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c) if err != nil { f.l.WithError(err).Error("Error while creating firewall during reload") return } + fw.SetSNATAddressFromInterface(f) oldFw := f.firewall conntrack := oldFw.Conntrack diff --git a/main.go b/main.go index e924ce8a..975bdebf 100644 --- a/main.go +++ b/main.go @@ -66,8 +66,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err) } - snatAddr := netip.MustParseAddr("169.254.55.96") //todo get this from tun! - fw, err := NewFirewallFromConfig(l, pki.getCertState(), c, snatAddr) + fw, err := NewFirewallFromConfig(l, pki.getCertState(), c) if err != nil { return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err) } diff --git a/overlay/device.go b/overlay/device.go index b6077aba..bb14a76c 100644 --- a/overlay/device.go +++ b/overlay/device.go @@ -11,6 +11,8 @@ type Device interface { io.ReadWriteCloser Activate() error Networks() []netip.Prefix + UnsafeNetworks() []netip.Prefix + SNATAddress() netip.Prefix Name() string RoutesFor(netip.Addr) routing.Gateways SupportsMultiqueue() bool diff --git a/overlay/tun.go b/overlay/tun.go index 56c7fccf..8bac6502 100644 --- a/overlay/tun.go +++ b/overlay/tun.go @@ -1,6 +1,7 @@ package overlay import ( + "crypto/rand" "fmt" "net" "net/netip" @@ -129,3 +130,53 @@ func selectGateway(dest netip.Prefix, gateways []netip.Prefix) (netip.Prefix, er return netip.Prefix{}, fmt.Errorf("no gateway found for %v in the list of vpn networks", dest) } + +func prepareSnatAddr(d Device, l *logrus.Logger, c *config.C, routes []Route) netip.Prefix { + if !d.Networks()[0].Addr().Is6() { + return netip.Prefix{} //if we have an IPv4 assignment within the overlay, we don't need a snat address + } + + addSnatAddr := false + for _, un := range d.UnsafeNetworks() { //if we are an unsafe router for an IPv4 range + if un.Addr().Is4() { + addSnatAddr = true + break + } + } + for _, route := range routes { //or if we have a route defined into an IPv4 range + if route.Cidr.Addr().Is4() { + addSnatAddr = true //todo should this only apply to unsafe routes? + break + } + } + if !addSnatAddr { + return netip.Prefix{} + } + + var err error + out := netip.Addr{} + if a := c.GetString("tun.snat_address_for_4over6", ""); a != "" { + out, err = netip.ParseAddr(a) + if err != nil { + l.WithField("value", a).WithError(err).Warn("failed to parse tun.snat_address_for_4over6, will use a random value") + } else if !out.Is4() || !out.IsLinkLocalUnicast() { + l.WithField("value", out).Warn("tun.snat_address_for_4over6 must be an IPv4 address") + } + } + if !out.IsValid() { + octets := []byte{169, 254, 0, 0} + _, _ = rand.Read(octets[2:4]) + if octets[3] == 0 { + octets[3] = 1 //please no .0 addresses + } else if octets[2] == 255 && octets[3] == 255 { + octets[3] = 254 //please no broadcast addresses + } + ok := false + out, ok = netip.AddrFromSlice(octets) + if !ok { + l.Error("failed to produce a valid IPv4 address for tun.snat_address_for_4over6") + return netip.Prefix{} + } + } + return netip.PrefixFrom(out, 32) +} diff --git a/overlay/tun_android.go b/overlay/tun_android.go index f091772a..3ab3f8a7 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -19,14 +19,15 @@ import ( type tun struct { io.ReadWriteCloser - fd int - vpnNetworks []netip.Prefix - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *logrus.Logger + fd int + vpnNetworks []netip.Prefix + unsafeNetworks []netip.Prefix + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] + l *logrus.Logger } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix, _ []netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix) (*tun, error) { // XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly. // Be sure not to call file.Fd() as it will set the fd to blocking mode. file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") @@ -35,6 +36,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net ReadWriteCloser: file, fd: deviceFd, vpnNetworks: vpnNetworks, + unsafeNetworks: unsafeNetworks, l: l, } @@ -91,6 +93,10 @@ func (t *tun) Networks() []netip.Prefix { return t.vpnNetworks } +func (t *tun) UnsafeNetworks() []netip.Prefix { + return t.UnsafeNetworks() +} + func (t *tun) Name() string { return "android" } diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index c9c3927e..23be219f 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -24,13 +24,15 @@ import ( type tun struct { io.ReadWriteCloser - Device string - vpnNetworks []netip.Prefix - DefaultMTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[routing.Gateways]] - linkAddr *netroute.LinkAddr - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + unsafeNetworks []netip.Prefix + snatAddr netip.Prefix + DefaultMTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] + linkAddr *netroute.LinkAddr + l *logrus.Logger // cache out buffer since we need to prepend 4 bytes for tun metadata out []byte @@ -127,6 +129,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNet ReadWriteCloser: os.NewFile(uintptr(fd), ""), Device: name, vpnNetworks: vpnNetworks, + unsafeNetworks: unsafeNetworks, DefaultMTU: c.GetInt("tun.mtu", DefaultMTU), l: l, } @@ -545,6 +548,14 @@ func (t *tun) Networks() []netip.Prefix { return t.vpnNetworks } +func (t *tun) UnsafeNetworks() []netip.Prefix { + return t.unsafeNetworks +} + +func (t *tun) SNATAddress() netip.Prefix { + return t.snatAddr +} + func (t *tun) Name() string { return t.Device } diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go index aa3dddaf..db976d10 100644 --- a/overlay/tun_disabled.go +++ b/overlay/tun_disabled.go @@ -22,6 +22,13 @@ type disabledTun struct { l *logrus.Logger } +func (*disabledTun) UnsafeNetworks() []netip.Prefix { + return nil +} +func (*disabledTun) SNATAddress() netip.Prefix { + return netip.Prefix{} +} + func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun { tun := &disabledTun{ vpnNetworks: vpnNetworks, diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index c66e45f1..bae53235 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -86,14 +86,16 @@ type ifreqAlias6 struct { } type tun struct { - Device string - vpnNetworks []netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[routing.Gateways]] - linkAddr *netroute.LinkAddr - l *logrus.Logger - devFd int + Device string + vpnNetworks []netip.Prefix + unsafeNetworks []netip.Prefix + snatAddr netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] + linkAddr *netroute.LinkAddr + l *logrus.Logger + devFd int } func (t *tun) Read(to []byte) (int, error) { @@ -270,11 +272,12 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNet } t := &tun{ - Device: deviceName, - vpnNetworks: vpnNetworks, - MTU: c.GetInt("tun.mtu", DefaultMTU), - l: l, - devFd: fd, + Device: deviceName, + vpnNetworks: vpnNetworks, + unsafeNetworks: unsafeNetworks, + MTU: c.GetInt("tun.mtu", DefaultMTU), + l: l, + devFd: fd, } err = t.reload(c, true) @@ -446,6 +449,14 @@ func (t *tun) Networks() []netip.Prefix { return t.vpnNetworks } +func (t *tun) UnsafeNetworks() []netip.Prefix { + return t.unsafeNetworks +} + +func (t *tun) SNATAddress() netip.Prefix { + return t.snatAddr +} + func (t *tun) Name() string { return t.Device } diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index 85466d1e..963e49c2 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -22,20 +22,22 @@ import ( type tun struct { io.ReadWriteCloser - vpnNetworks []netip.Prefix - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *logrus.Logger + vpnNetworks []netip.Prefix + unsafeNetworks []netip.Prefix + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] + l *logrus.Logger } func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ []netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in iOS") } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix, _ []netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/tun") t := &tun{ vpnNetworks: vpnNetworks, + unsafeNetworks: unsafeNetworks, ReadWriteCloser: &tunReadCloser{f: file}, l: l, } @@ -147,6 +149,14 @@ func (t *tun) Networks() []netip.Prefix { return t.vpnNetworks } +func (t *tun) UnsafeNetworks() []netip.Prefix { + return t.unsafeNetworks +} + +func (t *tun) SNATAddress() netip.Prefix { + return t.snatAddr +} + func (t *tun) Name() string { return "iOS" } diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 4ba815e9..1b604c2a 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -4,7 +4,6 @@ package overlay import ( - "crypto/rand" "fmt" "io" "net" @@ -57,6 +56,14 @@ func (t *tun) Networks() []netip.Prefix { return t.vpnNetworks } +func (t *tun) UnsafeNetworks() []netip.Prefix { + return t.unsafeNetworks +} + +func (t *tun) SNATAddress() netip.Prefix { + return t.snatAddr +} + type ifReq struct { Name [16]byte Flags uint16 @@ -165,59 +172,6 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []n return t, nil } -func (t *tun) prepareSnatAddr(c *config.C, initial bool, routes []Route) netip.Prefix { - if !initial { - return netip.Prefix{} //I don't wanna think about reloading this yet - } - if !t.vpnNetworks[0].Addr().Is6() { - return netip.Prefix{} //if we have an IPv4 assignment within the overlay, we don't need a snat address - } - - addSnatAddr := false - for _, un := range t.unsafeNetworks { //if we are an unsafe router for an IPv4 range - if un.Addr().Is4() { - addSnatAddr = true - break - } - } - for _, route := range routes { //or if we have a route defined into an IPv4 range - if route.Cidr.Addr().Is4() { - addSnatAddr = true //todo should this only apply to unsafe routes? - break - } - } - if !addSnatAddr { - return netip.Prefix{} - } - - var err error - out := netip.Addr{} - if a := c.GetString("tun.snat_address_for_4over6", ""); a != "" { - out, err = netip.ParseAddr(a) - if err != nil { - t.l.WithField("value", a).WithError(err).Warn("failed to parse tun.snat_address_for_4over6, will use a random value") - } else if !out.Is4() || !out.IsLinkLocalUnicast() { - t.l.WithField("value", t.snatAddr).Warn("tun.snat_address_for_4over6 must be an IPv4 address") - } - } - if !out.IsValid() { - octets := []byte{169, 254, 0, 0} - _, _ = rand.Read(octets[2:4]) - if octets[3] == 0 { - octets[3] = 1 //please no .0 addresses - } else if octets[2] == 255 && octets[3] == 255 { - octets[3] = 254 //please no broadcast addresses - } - ok := false - out, ok = netip.AddrFromSlice(octets) - if !ok { - t.l.Error("failed to produce a valid IPv4 address for tun.snat_address_for_4over6") - return netip.Prefix{} - } - } - return netip.PrefixFrom(out, 32) -} - func (t *tun) reload(c *config.C, initial bool) error { routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { @@ -228,7 +182,9 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } - t.snatAddr = t.prepareSnatAddr(c, initial, routes) + if !initial { + t.snatAddr = prepareSnatAddr(t, t.l, c, routes) + } routeTree, err := makeRouteTree(t.l, routes, true) if err != nil { diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index 39336108..5174adb3 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -58,14 +58,16 @@ type addrLifetime struct { } type tun struct { - Device string - vpnNetworks []netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *logrus.Logger - f *os.File - fd int + Device string + vpnNetworks []netip.Prefix + unsafeNetworks []netip.Prefix + snatAddr netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] + l *logrus.Logger + f *os.File + fd int } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) @@ -386,6 +388,14 @@ func (t *tun) Networks() []netip.Prefix { return t.vpnNetworks } +func (t *tun) UnsafeNetworks() []netip.Prefix { + return t.unsafeNetworks +} + +func (t *tun) SNATAddress() netip.Prefix { + return t.snatAddr +} + func (t *tun) Name() string { return t.Device } diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index 701d97dd..9f0d6567 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -49,14 +49,16 @@ type ifreq struct { } type tun struct { - Device string - vpnNetworks []netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *logrus.Logger - f *os.File - fd int + Device string + vpnNetworks []netip.Prefix + unsafeNetworks []netip.Prefix + snatAddr netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] + l *logrus.Logger + f *os.File + fd int // cache out buffer since we need to prepend 4 bytes for tun metadata out []byte } @@ -89,12 +91,13 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNet } t := &tun{ - f: os.NewFile(uintptr(fd), ""), - fd: fd, - Device: deviceName, - vpnNetworks: vpnNetworks, - MTU: c.GetInt("tun.mtu", DefaultMTU), - l: l, + f: os.NewFile(uintptr(fd), ""), + fd: fd, + Device: deviceName, + vpnNetworks: vpnNetworks, + unsafeNetworks: unsafeNetworks, + MTU: c.GetInt("tun.mtu", DefaultMTU), + l: l, } err = t.reload(c, true) @@ -306,6 +309,14 @@ func (t *tun) Networks() []netip.Prefix { return t.vpnNetworks } +func (t *tun) UnsafeNetworks() []netip.Prefix { + return t.unsafeNetworks +} + +func (t *tun) SNATAddress() netip.Prefix { + return t.snatAddr +} + func (t *tun) Name() string { return t.Device } diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index 145eccb9..3e876cb5 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -17,11 +17,12 @@ import ( ) type TestTun struct { - Device string - vpnNetworks []netip.Prefix - Routes []Route - routeTree *bart.Table[routing.Gateways] - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + unsafeNetworks []netip.Prefix + Routes []Route + routeTree *bart.Table[routing.Gateways] + l *logrus.Logger closed atomic.Bool rxPackets chan []byte // Packets to receive into nebula @@ -39,13 +40,14 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNet } return &TestTun{ - Device: c.GetString("tun.dev", ""), - vpnNetworks: vpnNetworks, - Routes: routes, - routeTree: routeTree, - l: l, - rxPackets: make(chan []byte, 10), - TxPackets: make(chan []byte, 10), + Device: c.GetString("tun.dev", ""), + vpnNetworks: vpnNetworks, + unsafeNetworks: unsafeNetworks, + Routes: routes, + routeTree: routeTree, + l: l, + rxPackets: make(chan []byte, 10), + TxPackets: make(chan []byte, 10), }, nil } @@ -139,3 +141,11 @@ func (t *TestTun) SupportsMultiqueue() bool { func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented") } + +func (t *tun) UnsafeNetworks() []netip.Prefix { + return t.UnsafeNetworks() +} + +func (t *tun) SNATAddress() netip.Prefix { + return netip.Prefix{} +} diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index d434d66d..cab52d99 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -28,12 +28,14 @@ import ( const tunGUIDLabel = "Fixed Nebula Windows GUID v1" type winTun struct { - Device string - vpnNetworks []netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + unsafeNetworks []netip.Prefix + snatAddr netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] + l *logrus.Logger tun *wintun.NativeTun } @@ -55,10 +57,11 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNet } t := &winTun{ - Device: deviceName, - vpnNetworks: vpnNetworks, - MTU: c.GetInt("tun.mtu", DefaultMTU), - l: l, + Device: deviceName, + vpnNetworks: vpnNetworks, + unsafeNetworks: unsafeNetworks, + MTU: c.GetInt("tun.mtu", DefaultMTU), + l: l, } err = t.reload(c, true) @@ -102,6 +105,10 @@ func (t *winTun) reload(c *config.C, initial bool) error { return nil } + if !initial { + t.snatAddr = prepareSnatAddr(t, t.l, c, routes) + } + routeTree, err := makeRouteTree(t.l, routes, false) if err != nil { return err @@ -225,6 +232,14 @@ func (t *winTun) Networks() []netip.Prefix { return t.vpnNetworks } +func (t *winTun) UnsafeNetworks() []netip.Prefix { + return t.unsafeNetworks +} + +func (t *winTun) SNATAddress() netip.Prefix { + return t.snatAddr +} + func (t *winTun) Name() string { return t.Device } diff --git a/overlay/user.go b/overlay/user.go index 52fa0df7..1c01dd1c 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -36,6 +36,14 @@ type UserDevice struct { inboundWriter *io.PipeWriter } +func (d *UserDevice) UnsafeNetworks() []netip.Prefix { + return nil +} + +func (d *UserDevice) SNATAddress() netip.Prefix { + return netip.Prefix{} +} + func (d *UserDevice) Activate() error { return nil } diff --git a/test/tun.go b/test/tun.go index fb32782f..182ee88d 100644 --- a/test/tun.go +++ b/test/tun.go @@ -10,6 +10,16 @@ import ( type NoopTun struct{} +func (NoopTun) Routes() []Route { + //TODO implement me + panic("implement me") +} + +func (NoopTun) UnsafeNetworks() []netip.Prefix { + //TODO implement me + panic("implement me") +} + func (NoopTun) RoutesFor(addr netip.Addr) routing.Gateways { return routing.Gateways{} } From 27d764ba57908564071adb402ca2328ee27e5952 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Tue, 17 Feb 2026 13:50:20 -0600 Subject: [PATCH 04/31] auto-assign snataddr on Mac+Windows --- overlay/tun_darwin.go | 9 +++++++++ overlay/tun_windows.go | 7 ++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 23be219f..651f4b2d 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -216,6 +216,11 @@ func (t *tun) Activate() error { } } } + if t.snatAddr.IsValid() && t.snatAddr.Addr().Is4() { + if err = t.activate4(t.snatAddr); err != nil { + return err + } + } // Run the interface ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING @@ -317,6 +322,10 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } + if !initial { + t.snatAddr = prepareSnatAddr(t, t.l, c, routes) + } + routeTree, err := makeRouteTree(t.l, routes, false) if err != nil { return err diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index cab52d99..248db47c 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -139,7 +139,12 @@ func (t *winTun) reload(c *config.C, initial bool) error { func (t *winTun) Activate() error { luid := winipcfg.LUID(t.tun.LUID()) - err := luid.SetIPAddresses(t.vpnNetworks) + prefixes := t.vpnNetworks + if t.snatAddr.IsValid() { + prefixes = append(prefixes, t.snatAddr) + } + + err := luid.SetIPAddresses(prefixes) if err != nil { return fmt.Errorf("failed to set address: %w", err) } From 7498c6846d4eeefddd7d17843aa0da2c286986c2 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Tue, 17 Feb 2026 15:00:00 -0600 Subject: [PATCH 05/31] checkpt --- overlay/tun_darwin.go | 2 +- overlay/tun_freebsd.go | 4 ++++ overlay/tun_linux.go | 16 +++++++++------- overlay/tun_netbsd.go | 4 ++++ overlay/tun_openbsd.go | 4 ++++ overlay/tun_windows.go | 2 +- 6 files changed, 23 insertions(+), 9 deletions(-) diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 651f4b2d..0ab331bb 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -322,7 +322,7 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } - if !initial { + if initial { t.snatAddr = prepareSnatAddr(t, t.l, c, routes) } diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index bae53235..31289d55 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -413,6 +413,10 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } + if initial { + t.snatAddr = prepareSnatAddr(t, t.l, c, routes) + } + routeTree, err := makeRouteTree(t.l, routes, false) if err != nil { return err diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 1b604c2a..0f3bec96 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -182,7 +182,7 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } - if !initial { + if initial { t.snatAddr = prepareSnatAddr(t, t.l, c, routes) } @@ -328,7 +328,8 @@ func (t *tun) addIPs(link netlink.Link) error { Label: t.vpnNetworks[i].Addr().Zone(), } } - if t.snatAddr.IsValid() { + + if t.snatAddr.IsValid() && len(t.vpnNetworks) > 0 { //TODO unsafe-routers should be able to snat and be snatted newAddrs = append(newAddrs, &netlink.Addr{ IPNet: &net.IPNet{ IP: t.snatAddr.Addr().AsSlice(), @@ -429,11 +430,12 @@ func (t *tun) Activate() error { return fmt.Errorf("failed to set default route MTU for %s: %w", t.vpnNetworks[i], err) } } - if t.snatAddr.IsValid() { - if err = t.setDefaultRoute(t.snatAddr); err != nil { - return fmt.Errorf("failed to set default route MTU for %s: %w", t.snatAddr, err) - } - } + //TODO snat and be snatted + //if t.snatAddr.IsValid() { + // if err = t.setDefaultRoute(t.snatAddr); err != nil { + // return fmt.Errorf("failed to set default route MTU for %s: %w", t.snatAddr, err) + // } + //} // Set the routes if err = t.addRoutes(false); err != nil { diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index 5174adb3..e81e466c 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -352,6 +352,10 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } + if initial { + t.snatAddr = prepareSnatAddr(t, t.l, c, routes) + } + routeTree, err := makeRouteTree(t.l, routes, false) if err != nil { return err diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index 9f0d6567..e88bd0f4 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -273,6 +273,10 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } + if initial { + t.snatAddr = prepareSnatAddr(t, t.l, c, routes) + } + routeTree, err := makeRouteTree(t.l, routes, false) if err != nil { return err diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index 248db47c..4f8bb5b9 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -105,7 +105,7 @@ func (t *winTun) reload(c *config.C, initial bool) error { return nil } - if !initial { + if initial { t.snatAddr = prepareSnatAddr(t, t.l, c, routes) } From 37abdd7f96fe7fdc2015bc2b51aef5cca00bd89c Mon Sep 17 00:00:00 2001 From: JackDoan Date: Tue, 17 Feb 2026 15:15:10 -0600 Subject: [PATCH 06/31] it works again but linux is pickier than I thought, I need to refactor even more --- firewall.go | 4 +++- overlay/tun_linux.go | 23 +++++++++++++++++------ 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/firewall.go b/firewall.go index 58dea318..f5137946 100644 --- a/firewall.go +++ b/firewall.go @@ -356,7 +356,9 @@ func (f *Firewall) GetRuleHashes() string { func (f *Firewall) SetSNATAddressFromInterface(i *Interface) { //address-mutation-avoidance is done inside Interface, the firewall doesn't need to care //todo should snatted conntracks get expired out? Probably not needed until if/when we allow reload - f.snatAddr = i.inside.SNATAddress().Addr() + if f.hasUnsafeNetworks { //todo this logic??? + f.snatAddr = i.inside.SNATAddress().Addr() + } } func (f *Firewall) ShouldUnSNAT(fp *firewall.Packet) bool { diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 0f3bec96..0382c3e2 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -329,7 +329,7 @@ func (t *tun) addIPs(link netlink.Link) error { } } - if t.snatAddr.IsValid() && len(t.vpnNetworks) > 0 { //TODO unsafe-routers should be able to snat and be snatted + if t.snatAddr.IsValid() && len(t.unsafeNetworks) == 0 { //TODO unsafe-routers should be able to snat and be snatted newAddrs = append(newAddrs, &netlink.Addr{ IPNet: &net.IPNet{ IP: t.snatAddr.Addr().AsSlice(), @@ -431,11 +431,11 @@ func (t *tun) Activate() error { } } //TODO snat and be snatted - //if t.snatAddr.IsValid() { - // if err = t.setDefaultRoute(t.snatAddr); err != nil { - // return fmt.Errorf("failed to set default route MTU for %s: %w", t.snatAddr, err) - // } - //} + if t.snatAddr.IsValid() && len(t.unsafeNetworks) == 0 { + if err = t.setDefaultRoute(t.snatAddr); err != nil { + return fmt.Errorf("failed to set default route MTU for %s: %w", t.snatAddr, err) + } + } // Set the routes if err = t.addRoutes(false); err != nil { @@ -448,6 +448,14 @@ func (t *tun) Activate() error { return fmt.Errorf("failed to run tun device: %s", err) } + //todo hmmmmmm + if len(t.unsafeNetworks) != 0 { + err = os.WriteFile(fmt.Sprintf("/proc/sys/net/ipv4/conf/%s/accept_local", t.Device), []byte("1"), os.FileMode(0o644)) + if err != nil { + return err + } + } + return nil } @@ -556,6 +564,9 @@ func (t *tun) addRoutes(logErrors bool) error { } } + if len(t.unsafeNetworks) == 0 { + return nil + } return t.setSnatRoute() } From 92ee45ed137d8f5fd46f14529bde4888ed37c497 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Tue, 17 Feb 2026 15:16:36 -0600 Subject: [PATCH 07/31] tun tester more useful --- firewall_test.go | 18 +++++++++--------- overlay/tun_linux.go | 13 +++++++------ overlay/tun_tester.go | 15 +++++++++------ test/tun.go | 10 ++++------ 4 files changed, 29 insertions(+), 27 deletions(-) diff --git a/firewall_test.go b/firewall_test.go index 282929f2..c42cad65 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -1047,53 +1047,53 @@ func TestNewFirewallFromConfig(t *testing.T) { conf := config.NewC(l) conf.Settings["firewall"] = map[string]any{"outbound": "asdf"} - _, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{}) + _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules") // Test both port and code conf = config.NewC(l) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}} - _, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{}) + _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided") // Test missing host, group, cidr, ca_name and ca_sha conf = config.NewC(l) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}} - _, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{}) + _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided") // Test code/port error conf = config.NewC(l) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh", "proto": "any"}}} - _, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{}) + _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh", "proto": "any"}}} - _, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{}) + _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") // Test proto error conf = config.NewC(l) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}} - _, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{}) + _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``") // Test cidr parse error conf = config.NewC(l) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}} - _, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{}) + _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test local_cidr parse error conf = config.NewC(l) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}} - _, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{}) + _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test both group and groups conf = config.NewC(l) conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}} - _, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{}) + _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided") } diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 0382c3e2..0569fcd8 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -449,12 +449,13 @@ func (t *tun) Activate() error { } //todo hmmmmmm - if len(t.unsafeNetworks) != 0 { - err = os.WriteFile(fmt.Sprintf("/proc/sys/net/ipv4/conf/%s/accept_local", t.Device), []byte("1"), os.FileMode(0o644)) - if err != nil { - return err - } - } + //pretty sure this is avoidable + //if len(t.unsafeNetworks) != 0 { + // err = os.WriteFile(fmt.Sprintf("/proc/sys/net/ipv4/conf/%s/accept_local", t.Device), []byte("1"), os.FileMode(0o644)) + // if err != nil { + // return err + // } + //} return nil } diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index 3e876cb5..cb96c195 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -20,6 +20,7 @@ type TestTun struct { Device string vpnNetworks []netip.Prefix unsafeNetworks []netip.Prefix + snatAddr netip.Prefix Routes []Route routeTree *bart.Table[routing.Gateways] l *logrus.Logger @@ -39,7 +40,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNet return nil, err } - return &TestTun{ + tt := &TestTun{ Device: c.GetString("tun.dev", ""), vpnNetworks: vpnNetworks, unsafeNetworks: unsafeNetworks, @@ -48,7 +49,9 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNet l: l, rxPackets: make(chan []byte, 10), TxPackets: make(chan []byte, 10), - }, nil + } + tt.snatAddr = prepareSnatAddr(tt, l, c, routes) + return tt, nil } func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*TestTun, error) { @@ -142,10 +145,10 @@ func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented") } -func (t *tun) UnsafeNetworks() []netip.Prefix { - return t.UnsafeNetworks() +func (t *TestTun) UnsafeNetworks() []netip.Prefix { + return t.unsafeNetworks } -func (t *tun) SNATAddress() netip.Prefix { - return netip.Prefix{} +func (t *TestTun) SNATAddress() netip.Prefix { + return t.snatAddr } diff --git a/test/tun.go b/test/tun.go index 182ee88d..37728f6c 100644 --- a/test/tun.go +++ b/test/tun.go @@ -10,14 +10,12 @@ import ( type NoopTun struct{} -func (NoopTun) Routes() []Route { - //TODO implement me - panic("implement me") +func (NoopTun) UnsafeNetworks() []netip.Prefix { + return nil } -func (NoopTun) UnsafeNetworks() []netip.Prefix { - //TODO implement me - panic("implement me") +func (NoopTun) SNATAddress() netip.Prefix { + return netip.Prefix{} } func (NoopTun) RoutesFor(addr netip.Addr) routing.Gateways { From 25610225bbff9ef46071486e0543f8d8fa95b47d Mon Sep 17 00:00:00 2001 From: JackDoan Date: Wed, 18 Feb 2026 15:07:57 -0600 Subject: [PATCH 08/31] crappy AI tests --- e2e/snat_test.go | 400 ++++++++++++ overlay/tun_snat_test.go | 171 +++++ snat_test.go | 1309 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 1880 insertions(+) create mode 100644 e2e/snat_test.go create mode 100644 overlay/tun_snat_test.go create mode 100644 snat_test.go diff --git a/e2e/snat_test.go b/e2e/snat_test.go new file mode 100644 index 00000000..a0e0af96 --- /dev/null +++ b/e2e/snat_test.go @@ -0,0 +1,400 @@ +//go:build e2e_testing +// +build e2e_testing + +package e2e + +import ( + "encoding/binary" + "net/netip" + "testing" + "time" + + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/cert_test" + "github.com/slackhq/nebula/e2e/router" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// parseIPv4UDPPacket extracts source/dest IPs, ports, and payload from an IPv4 UDP packet. +func parseIPv4UDPPacket(t testing.TB, pkt []byte) (srcIP, dstIP netip.Addr, srcPort, dstPort uint16, payload []byte) { + t.Helper() + require.True(t, len(pkt) >= 28, "packet too short for IPv4+UDP header") + require.Equal(t, byte(0x45), pkt[0]&0xF0|pkt[0]&0x0F, "not a simple IPv4 packet (IHL!=5)") + + srcIP, _ = netip.AddrFromSlice(pkt[12:16]) + dstIP, _ = netip.AddrFromSlice(pkt[16:20]) + + ihl := int(pkt[0]&0x0F) * 4 + require.True(t, len(pkt) >= ihl+8, "packet too short for UDP header") + srcPort = binary.BigEndian.Uint16(pkt[ihl : ihl+2]) + dstPort = binary.BigEndian.Uint16(pkt[ihl+2 : ihl+4]) + udpLen := binary.BigEndian.Uint16(pkt[ihl+4 : ihl+6]) + payload = pkt[ihl+8 : ihl+int(udpLen)] + return +} + +func TestSNAT_IPv6OnlyPeer_IPv4UnsafeTraffic(t *testing.T) { + // Scenario: Two IPv6-only VPN nodes. The "router" node has unsafe networks + // (192.168.0.0/16) in its cert and a configured SNAT address. The "sender" + // node has an unsafe route for 192.168.0.0/16 via the router. + // + // When sender injects an IPv4 packet destined for the unsafe network, it + // gets tunneled to the router. The router's firewall detects this is IPv4 + // from an IPv6-only peer and applies SNAT, rewriting the source IP to the + // SNAT address before delivering it to TUN. + // + // When a reply comes back from TUN addressed to the SNAT address, the + // router un-SNATs it (restoring the original destination) and tunnels it + // back to the sender. + + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + + unsafePrefix := "192.168.0.0/16" + snatAddr := netip.MustParseAddr("169.254.42.42") + + // Router: IPv6-only with unsafe networks and a manual SNAT address. + // Override inbound firewall with local_cidr: "any" so both IPv4 (unsafe) + // and IPv6 (VPN) traffic is accepted. + routerControl, routerVpnIpNet, routerUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks( + cert.Version2, ca, caKey, "router", "ff::1/64", + netip.MustParseAddrPort("[beef::1]:4242"), + unsafePrefix, + m{ + "firewall": m{ + "inbound": []m{{ + "proto": "any", + "port": "any", + "host": "any", + "local_cidr": "any", + }}, + }, + "tun": m{ + "snat_address_for_4over6": snatAddr.String(), + }, + }, + ) + + // Sender: IPv6-only with an unsafe route via the router + senderControl, _, _, _ := newSimpleServerWithUdp( + cert.Version2, ca, caKey, "sender", "ff::2/64", + netip.MustParseAddrPort("[beef::2]:4242"), + m{ + "tun": m{ + "unsafe_routes": []m{ + {"route": unsafePrefix, "via": routerVpnIpNet[0].Addr().String()}, + }, + }, + }, + ) + + // Tell sender where the router lives + senderControl.InjectLightHouseAddr(routerVpnIpNet[0].Addr(), routerUdpAddr) + + // Build the router and start both nodes + r := router.NewR(t, routerControl, senderControl) + defer r.RenderFlow() + + routerControl.Start() + senderControl.Start() + + // --- Outbound: sender -> IPv4 unsafe dest (via router with SNAT) --- + + origSrcIP := netip.MustParseAddr("10.0.0.1") + unsafeDest := netip.MustParseAddr("192.168.1.1") + var origSrcPort uint16 = 12345 + var dstPort uint16 = 80 + + t.Log("Sender injects an IPv4 packet to the unsafe network") + senderControl.InjectTunUDPPacket(unsafeDest, dstPort, origSrcIP, origSrcPort, []byte("snat me")) + + t.Log("Route packets (handshake + data) until the router gets the packet on TUN") + snatPkt := r.RouteForAllUntilTxTun(routerControl) + + t.Log("Verify the packet was SNATted") + gotSrcIP, gotDstIP, gotSrcPort, gotDstPort, gotPayload := parseIPv4UDPPacket(t, snatPkt) + assert.Equal(t, snatAddr, gotSrcIP, "source IP should be rewritten to the SNAT address") + assert.Equal(t, unsafeDest, gotDstIP, "destination IP should be unchanged") + assert.Equal(t, dstPort, gotDstPort, "destination port should be unchanged") + assert.Equal(t, []byte("snat me"), gotPayload, "payload should be unchanged") + + // Capture the SNAT port (may differ from original if port was remapped) + snatPort := gotSrcPort + t.Logf("SNAT port: %d (original: %d)", snatPort, origSrcPort) + + // --- Return: reply from unsafe dest -> un-SNATted back to sender --- + + t.Log("Router injects a reply packet from the unsafe dest to the SNAT address") + routerControl.InjectTunUDPPacket(snatAddr, snatPort, unsafeDest, dstPort, []byte("reply from unsafe")) + + t.Log("Route until sender gets the reply on TUN") + replyPkt := r.RouteForAllUntilTxTun(senderControl) + + t.Log("Verify the reply was un-SNATted") + replySrcIP, replyDstIP, replySrcPort, replyDstPort, replyPayload := parseIPv4UDPPacket(t, replyPkt) + assert.Equal(t, unsafeDest, replySrcIP, "reply source should be the unsafe dest") + assert.Equal(t, origSrcIP, replyDstIP, "reply dest should be the original source IP (un-SNATted)") + assert.Equal(t, dstPort, replySrcPort, "reply source port should be the unsafe dest port") + assert.Equal(t, origSrcPort, replyDstPort, "reply dest port should be the original source port (un-SNATted)") + assert.Equal(t, []byte("reply from unsafe"), replyPayload, "payload should be unchanged") + + r.RenderHostmaps("Final hostmaps", routerControl, senderControl) + + // Also verify normal IPv6 VPN traffic still works between the nodes + t.Log("Verify normal IPv6 VPN tunnel still works") + assertTunnel(t, routerVpnIpNet[0].Addr(), senderControl.GetVpnAddrs()[0], routerControl, senderControl, r) + + routerControl.Stop() + senderControl.Stop() +} + +func TestSNAT_MultipleFlows(t *testing.T) { + // Test that multiple distinct IPv4 flows from the same IPv6-only peer + // are tracked independently through SNAT. + + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + + unsafePrefix := "192.168.0.0/16" + snatAddr := netip.MustParseAddr("169.254.42.42") + + routerControl, routerVpnIpNet, routerUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks( + cert.Version2, ca, caKey, "router", "ff::1/64", + netip.MustParseAddrPort("[beef::1]:4242"), + unsafePrefix, + m{ + "firewall": m{ + "inbound": []m{{ + "proto": "any", + "port": "any", + "host": "any", + "local_cidr": "any", + }}, + }, + "tun": m{ + "snat_address_for_4over6": snatAddr.String(), + }, + }, + ) + + senderControl, _, _, _ := newSimpleServerWithUdp( + cert.Version2, ca, caKey, "sender", "ff::2/64", + netip.MustParseAddrPort("[beef::2]:4242"), + m{ + "tun": m{ + "unsafe_routes": []m{ + {"route": unsafePrefix, "via": routerVpnIpNet[0].Addr().String()}, + }, + }, + }, + ) + + senderControl.InjectLightHouseAddr(routerVpnIpNet[0].Addr(), routerUdpAddr) + + r := router.NewR(t, routerControl, senderControl) + defer r.RenderFlow() + r.CancelFlowLogs() + + routerControl.Start() + senderControl.Start() + + unsafeDest := netip.MustParseAddr("192.168.1.1") + + // Send first flow + senderControl.InjectTunUDPPacket(unsafeDest, 80, netip.MustParseAddr("10.0.0.1"), 1111, []byte("flow1")) + pkt1 := r.RouteForAllUntilTxTun(routerControl) + srcIP1, _, srcPort1, _, payload1 := parseIPv4UDPPacket(t, pkt1) + assert.Equal(t, snatAddr, srcIP1) + assert.Equal(t, []byte("flow1"), payload1) + + // Send second flow (different source port) + senderControl.InjectTunUDPPacket(unsafeDest, 80, netip.MustParseAddr("10.0.0.1"), 2222, []byte("flow2")) + pkt2 := r.RouteForAllUntilTxTun(routerControl) + srcIP2, _, srcPort2, _, payload2 := parseIPv4UDPPacket(t, pkt2) + assert.Equal(t, snatAddr, srcIP2) + assert.Equal(t, []byte("flow2"), payload2) + + // The two flows should have different SNAT ports (since they're different conntracks) + t.Logf("Flow 1 SNAT port: %d, Flow 2 SNAT port: %d", srcPort1, srcPort2) + + // Reply to flow 2 first (out of order) + routerControl.InjectTunUDPPacket(snatAddr, srcPort2, unsafeDest, 80, []byte("reply2")) + reply2 := r.RouteForAllUntilTxTun(senderControl) + _, replyDst2, _, replyDstPort2, replyPayload2 := parseIPv4UDPPacket(t, reply2) + assert.Equal(t, netip.MustParseAddr("10.0.0.1"), replyDst2) + assert.Equal(t, uint16(2222), replyDstPort2, "reply to flow 2 should restore original port 2222") + assert.Equal(t, []byte("reply2"), replyPayload2) + + // Reply to flow 1 + routerControl.InjectTunUDPPacket(snatAddr, srcPort1, unsafeDest, 80, []byte("reply1")) + reply1 := r.RouteForAllUntilTxTun(senderControl) + _, replyDst1, _, replyDstPort1, replyPayload1 := parseIPv4UDPPacket(t, reply1) + assert.Equal(t, netip.MustParseAddr("10.0.0.1"), replyDst1) + assert.Equal(t, uint16(1111), replyDstPort1, "reply to flow 1 should restore original port 1111") + assert.Equal(t, []byte("reply1"), replyPayload1) + + routerControl.Stop() + senderControl.Stop() +} + +// --- Adversarial SNAT E2E Tests --- + +func TestSNAT_UnsolicitedReplyDropped(t *testing.T) { + // Without any outbound SNAT traffic, inject a packet from the router's TUN + // addressed to the SNAT address. The sender must never receive it because + // there's no conntrack entry to un-SNAT through. + + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + + unsafePrefix := "192.168.0.0/16" + snatAddr := netip.MustParseAddr("169.254.42.42") + + routerControl, routerVpnIpNet, routerUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks( + cert.Version2, ca, caKey, "router", "ff::1/64", + netip.MustParseAddrPort("[beef::1]:4242"), + unsafePrefix, + m{ + "firewall": m{ + "inbound": []m{{ + "proto": "any", + "port": "any", + "host": "any", + "local_cidr": "any", + }}, + }, + "tun": m{ + "snat_address_for_4over6": snatAddr.String(), + }, + }, + ) + + senderControl, _, _, _ := newSimpleServerWithUdp( + cert.Version2, ca, caKey, "sender", "ff::2/64", + netip.MustParseAddrPort("[beef::2]:4242"), + m{ + "tun": m{ + "unsafe_routes": []m{ + {"route": unsafePrefix, "via": routerVpnIpNet[0].Addr().String()}, + }, + }, + }, + ) + + senderControl.InjectLightHouseAddr(routerVpnIpNet[0].Addr(), routerUdpAddr) + + r := router.NewR(t, routerControl, senderControl) + defer r.RenderFlow() + r.CancelFlowLogs() + + routerControl.Start() + senderControl.Start() + + // First establish the tunnel with normal IPv6 traffic so handshake completes + assertTunnel(t, routerVpnIpNet[0].Addr(), senderControl.GetVpnAddrs()[0], routerControl, senderControl, r) + + // Inject the unsolicited reply from router's TUN to the SNAT address. + // There is NO prior outbound SNAT flow, so no conntrack entry exists. + // The router should silently drop this because unSnat finds no matching conntrack. + routerControl.InjectTunUDPPacket(snatAddr, 55555, netip.MustParseAddr("192.168.1.1"), 80, []byte("unsolicited")) + + // Send a canary IPv6 VPN packet after the bad one. Since the router processes + // TUN packets sequentially, the canary arriving proves the bad packet was processed first. + senderVpnAddr := senderControl.GetVpnAddrs()[0] + routerControl.InjectTunUDPPacket(senderVpnAddr, 90, routerVpnIpNet[0].Addr(), 80, []byte("canary")) + canaryPkt := r.RouteForAllUntilTxTun(senderControl) + assertUdpPacket(t, []byte("canary"), canaryPkt, routerVpnIpNet[0].Addr(), senderVpnAddr, 80, 90) + + // The unsolicited packet should have been dropped — nothing else on sender's TUN + got := senderControl.GetFromTun(false) + assert.Nil(t, got, "sender should not receive unsolicited packet to SNAT address with no conntrack entry") + + routerControl.Stop() + senderControl.Stop() +} + +func TestSNAT_NonUnsafeDestDropped(t *testing.T) { + // An IPv6-only sender sends IPv4 traffic to a destination outside the router's + // unsafe networks (172.16.0.1 when unsafe is 192.168.0.0/16). The router should + // reject this because the local address is not routable. This verifies that + // willingToHandleLocalAddr enforces boundaries on what SNAT traffic can reach. + + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + + unsafePrefix := "192.168.0.0/16" + snatAddr := netip.MustParseAddr("169.254.42.42") + + routerControl, routerVpnIpNet, routerUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks( + cert.Version2, ca, caKey, "router", "ff::1/64", + netip.MustParseAddrPort("[beef::1]:4242"), + unsafePrefix, + m{ + "firewall": m{ + "inbound": []m{{ + "proto": "any", + "port": "any", + "host": "any", + "local_cidr": "any", + }}, + }, + "tun": m{ + "snat_address_for_4over6": snatAddr.String(), + }, + }, + ) + + // Sender has unsafe routes for BOTH 192.168.0.0/16 AND 172.16.0.0/12 via router. + // This means the sender will route 172.16.0.1 through the tunnel to the router. + // But the router should reject it because 172.16.0.0/12 is NOT in its unsafe networks. + senderControl, _, _, _ := newSimpleServerWithUdp( + cert.Version2, ca, caKey, "sender", "ff::2/64", + netip.MustParseAddrPort("[beef::2]:4242"), + m{ + "tun": m{ + "unsafe_routes": []m{ + {"route": unsafePrefix, "via": routerVpnIpNet[0].Addr().String()}, + {"route": "172.16.0.0/12", "via": routerVpnIpNet[0].Addr().String()}, + }, + }, + }, + ) + + senderControl.InjectLightHouseAddr(routerVpnIpNet[0].Addr(), routerUdpAddr) + + r := router.NewR(t, routerControl, senderControl) + defer r.RenderFlow() + r.CancelFlowLogs() + + routerControl.Start() + senderControl.Start() + + // Establish the tunnel first + assertTunnel(t, routerVpnIpNet[0].Addr(), senderControl.GetVpnAddrs()[0], routerControl, senderControl, r) + + // Send to 172.16.0.1 (NOT in router's unsafe networks 192.168.0.0/16). + // The router should reject this at willingToHandleLocalAddr. + senderControl.InjectTunUDPPacket( + netip.MustParseAddr("172.16.0.1"), 80, + netip.MustParseAddr("10.0.0.1"), 12345, + []byte("wrong dest"), + ) + + // Send a canary to a valid unsafe destination to prove the bad packet was processed + senderControl.InjectTunUDPPacket( + netip.MustParseAddr("192.168.1.1"), 80, + netip.MustParseAddr("10.0.0.1"), 33333, + []byte("canary"), + ) + + // Route until the canary arrives — the 172.16.0.1 packet should have been + // processed and dropped before the canary gets through + canaryPkt := r.RouteForAllUntilTxTun(routerControl) + _, canaryDst, _, _, canaryPayload := parseIPv4UDPPacket(t, canaryPkt) + assert.Equal(t, netip.MustParseAddr("192.168.1.1"), canaryDst, "canary should arrive at the valid unsafe dest") + assert.Equal(t, []byte("canary"), canaryPayload) + + // No more packets — the 172.16.0.1 packet was dropped + got := routerControl.GetFromTun(false) + assert.Nil(t, got, "packet to non-unsafe destination 172.16.0.1 should be dropped by the router") + + routerControl.Stop() + senderControl.Stop() +} diff --git a/overlay/tun_snat_test.go b/overlay/tun_snat_test.go new file mode 100644 index 00000000..0040edb4 --- /dev/null +++ b/overlay/tun_snat_test.go @@ -0,0 +1,171 @@ +package overlay + +import ( + "io" + "net/netip" + "testing" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/routing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockDevice is a minimal Device implementation for testing prepareSnatAddr. +type mockDevice struct { + networks []netip.Prefix + unsafeNetworks []netip.Prefix + snatAddr netip.Prefix +} + +func (d *mockDevice) Read([]byte) (int, error) { return 0, nil } +func (d *mockDevice) Write([]byte) (int, error) { return 0, nil } +func (d *mockDevice) Close() error { return nil } +func (d *mockDevice) Activate() error { return nil } +func (d *mockDevice) Networks() []netip.Prefix { return d.networks } +func (d *mockDevice) UnsafeNetworks() []netip.Prefix { return d.unsafeNetworks } +func (d *mockDevice) SNATAddress() netip.Prefix { return d.snatAddr } +func (d *mockDevice) Name() string { return "mock" } +func (d *mockDevice) RoutesFor(netip.Addr) routing.Gateways { return routing.Gateways{} } +func (d *mockDevice) SupportsMultiqueue() bool { return false } +func (d *mockDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, nil } + +func TestPrepareSnatAddr_V4Primary_NoSnat(t *testing.T) { + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + c := config.NewC(l) + + // If the device has an IPv4 primary address, no SNAT needed + d := &mockDevice{ + networks: []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}, + } + result := prepareSnatAddr(d, l, c, nil) + assert.Equal(t, netip.Prefix{}, result, "should not assign SNAT addr when device has IPv4 primary") +} + +func TestPrepareSnatAddr_V6Primary_NoUnsafeOrRoutes(t *testing.T) { + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + c := config.NewC(l) + + // IPv6 primary but no unsafe networks or IPv4 routes + d := &mockDevice{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + } + result := prepareSnatAddr(d, l, c, nil) + assert.Equal(t, netip.Prefix{}, result, "should not assign SNAT addr without IPv4 unsafe networks or routes") +} + +func TestPrepareSnatAddr_V6Primary_WithV4Unsafe(t *testing.T) { + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + c := config.NewC(l) + + // IPv6 primary with IPv4 unsafe network -> should get SNAT addr + d := &mockDevice{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + result := prepareSnatAddr(d, l, c, nil) + require.True(t, result.IsValid(), "should assign SNAT addr") + assert.True(t, result.Addr().Is4(), "SNAT addr should be IPv4") + assert.True(t, result.Addr().IsLinkLocalUnicast(), "SNAT addr should be link-local") + assert.Equal(t, 32, result.Bits(), "SNAT addr should be /32") +} + +func TestPrepareSnatAddr_V6Primary_WithV4Route(t *testing.T) { + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + c := config.NewC(l) + + // IPv6 primary with IPv4 route -> should get SNAT addr + d := &mockDevice{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + } + routes := []Route{ + {Cidr: netip.MustParsePrefix("10.0.0.0/8")}, + } + result := prepareSnatAddr(d, l, c, routes) + require.True(t, result.IsValid(), "should assign SNAT addr when IPv4 route exists") + assert.True(t, result.Addr().Is4()) + assert.True(t, result.Addr().IsLinkLocalUnicast()) +} + +func TestPrepareSnatAddr_V6Primary_V6UnsafeOnly(t *testing.T) { + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + c := config.NewC(l) + + // IPv6 primary with only IPv6 unsafe network -> no SNAT needed + d := &mockDevice{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("fd01::/64")}, + } + result := prepareSnatAddr(d, l, c, nil) + assert.Equal(t, netip.Prefix{}, result, "should not assign SNAT addr for IPv6-only unsafe networks") +} + +func TestPrepareSnatAddr_ManualAddress(t *testing.T) { + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + c := config.NewC(l) + c.Settings["tun"] = map[string]any{ + "snat_address_for_4over6": "169.254.42.42", + } + + d := &mockDevice{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + result := prepareSnatAddr(d, l, c, nil) + require.True(t, result.IsValid()) + assert.Equal(t, netip.MustParseAddr("169.254.42.42"), result.Addr()) + assert.Equal(t, 32, result.Bits()) +} + +func TestPrepareSnatAddr_InvalidManualAddress_Fallback(t *testing.T) { + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + c := config.NewC(l) + c.Settings["tun"] = map[string]any{ + "snat_address_for_4over6": "not-an-ip", + } + + d := &mockDevice{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + result := prepareSnatAddr(d, l, c, nil) + // Should fall back to auto-assignment + require.True(t, result.IsValid(), "should fall back to auto-assigned address") + assert.True(t, result.Addr().Is4()) + assert.True(t, result.Addr().IsLinkLocalUnicast()) +} + +func TestPrepareSnatAddr_AutoGenerated_Range(t *testing.T) { + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + c := config.NewC(l) + + d := &mockDevice{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + + // Generate several addresses and verify they're all in the expected range + for i := 0; i < 100; i++ { + result := prepareSnatAddr(d, l, c, nil) + require.True(t, result.IsValid()) + addr := result.Addr() + octets := addr.As4() + assert.Equal(t, byte(169), octets[0], "first octet should be 169") + assert.Equal(t, byte(254), octets[1], "second octet should be 254") + // Should not have .0 in the last octet + assert.NotEqual(t, byte(0), octets[3], "last octet should not be 0") + // Should not be 169.254.255.255 (broadcast) + if octets[2] == 255 { + assert.NotEqual(t, byte(255), octets[3], "should not be broadcast address") + } + } +} diff --git a/snat_test.go b/snat_test.go new file mode 100644 index 00000000..b6e2a116 --- /dev/null +++ b/snat_test.go @@ -0,0 +1,1309 @@ +package nebula + +import ( + "encoding/binary" + "net/netip" + "slices" + "testing" + "time" + + "github.com/gaissmai/bart" + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/firewall" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Canonical test packets with all checksums computed from scratch by +// /tmp/gen_canonical.go. Tests feed these into the production rewrite +// functions and compare byte-for-byte against expected outputs. + +// canonicalUDP: src=10.0.0.1:12345 dst=192.168.1.1:80 proto=UDP payload="hello world" +var canonicalUDP = []byte{ + 0x45, 0x00, 0x00, 0x27, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x5c, 0xe8, 0x0a, 0x00, 0x00, 0x01, + 0xc0, 0xa8, 0x01, 0x01, 0x30, 0x39, 0x00, 0x50, 0x00, 0x13, 0x71, 0xc6, 0x68, 0x65, 0x6c, 0x6c, + 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64, +} + +// canonicalTCP: src=10.0.0.1:12345 dst=192.168.1.1:80 proto=TCP payload="GET / HTTP/1.1" +var canonicalTCP = []byte{ + 0x45, 0x00, 0x00, 0x36, 0x12, 0x34, 0x40, 0x00, 0x40, 0x06, 0x5c, 0xe4, 0x0a, 0x00, 0x00, 0x01, + 0xc0, 0xa8, 0x01, 0x01, 0x30, 0x39, 0x00, 0x50, 0x12, 0x34, 0x56, 0x78, 0x00, 0x00, 0x00, 0x00, + 0x50, 0x02, 0xff, 0xff, 0x86, 0x68, 0x00, 0x00, 0x47, 0x45, 0x54, 0x20, 0x2f, 0x20, 0x48, 0x54, + 0x54, 0x50, 0x2f, 0x31, 0x2e, 0x31, +} + +// canonicalICMP: src=10.0.0.1 dst=192.168.1.1 proto=ICMP echo, id=0x1234 seq=1 +var canonicalICMP = []byte{ + 0x45, 0x00, 0x00, 0x1c, 0x12, 0x34, 0x40, 0x00, 0x40, 0x01, 0x5d, 0x03, 0x0a, 0x00, 0x00, 0x01, + 0xc0, 0xa8, 0x01, 0x01, 0x08, 0x00, 0xe5, 0xca, 0x12, 0x34, 0x00, 0x01, +} + +// canonicalUDPReply: src=192.168.1.1:80 dst=169.254.55.96:55555 proto=UDP payload="reply" +var canonicalUDPReply = []byte{ + 0x45, 0x00, 0x00, 0x21, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x85, 0x90, 0xc0, 0xa8, 0x01, 0x01, + 0xa9, 0xfe, 0x37, 0x60, 0x00, 0x50, 0xd9, 0x03, 0x00, 0x0d, 0x27, 0xa6, 0x72, 0x65, 0x70, 0x6c, + 0x79, +} + +// canonicalUDPTest: src=10.0.0.1:12345 dst=192.168.1.1:80 proto=UDP payload="test" +var canonicalUDPTest = []byte{ + 0x45, 0x00, 0x00, 0x20, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x5c, 0xef, 0x0a, 0x00, 0x00, 0x01, + 0xc0, 0xa8, 0x01, 0x01, 0x30, 0x39, 0x00, 0x50, 0x00, 0x0c, 0x1b, 0xc9, 0x74, 0x65, 0x73, 0x74, +} + +// canonicalUDPHijack: src=10.0.0.1:12345 dst=192.168.1.1:80 proto=UDP payload="hijack" +var canonicalUDPHijack = []byte{ + 0x45, 0x00, 0x00, 0x22, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x5c, 0xed, 0x0a, 0x00, 0x00, 0x01, + 0xc0, 0xa8, 0x01, 0x01, 0x30, 0x39, 0x00, 0x50, 0x00, 0x0e, 0xcd, 0x68, 0x68, 0x69, 0x6a, 0x61, + 0x63, 0x6b, +} + +// canonicalUDPBlocked: src=10.0.0.1:12345 dst=192.168.1.1:443 proto=UDP payload="blocked" +var canonicalUDPBlocked = []byte{ + 0x45, 0x00, 0x00, 0x23, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x5c, 0xec, 0x0a, 0x00, 0x00, 0x01, + 0xc0, 0xa8, 0x01, 0x01, 0x30, 0x39, 0x01, 0xbb, 0x00, 0x0f, 0x60, 0xfc, 0x62, 0x6c, 0x6f, 0x63, + 0x6b, 0x65, 0x64, +} + +// canonicalUDPWrongDest: src=10.0.0.1:12345 dst=172.16.0.1:80 proto=UDP payload="wrong dest" +var canonicalUDPWrongDest = []byte{ + 0x45, 0x00, 0x00, 0x26, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x72, 0x81, 0x0a, 0x00, 0x00, 0x01, + 0xac, 0x10, 0x00, 0x01, 0x30, 0x39, 0x00, 0x50, 0x00, 0x12, 0xf3, 0x53, 0x77, 0x72, 0x6f, 0x6e, + 0x67, 0x20, 0x64, 0x65, 0x73, 0x74, +} + +// canonicalUDPNoSnat: src=10.0.0.1:12345 dst=192.168.1.1:80 proto=UDP payload="no snat" +var canonicalUDPNoSnat = []byte{ + 0x45, 0x00, 0x00, 0x23, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x5c, 0xec, 0x0a, 0x00, 0x00, 0x01, + 0xc0, 0xa8, 0x01, 0x01, 0x30, 0x39, 0x00, 0x50, 0x00, 0x0f, 0x92, 0x58, 0x6e, 0x6f, 0x20, 0x73, + 0x6e, 0x61, 0x74, +} + +// canonicalUDPV4Traffic: src=10.128.0.2:12345 dst=192.168.1.1:80 proto=UDP payload="v4 traffic" +var canonicalUDPV4Traffic = []byte{ + 0x45, 0x00, 0x00, 0x26, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x5c, 0x68, 0x0a, 0x80, 0x00, 0x02, + 0xc0, 0xa8, 0x01, 0x01, 0x30, 0x39, 0x00, 0x50, 0x00, 0x12, 0x2a, 0x42, 0x76, 0x34, 0x20, 0x74, + 0x72, 0x61, 0x66, 0x66, 0x69, 0x63, +} + +// canonicalUDPRoundtrip: src=10.0.0.1:12345 dst=192.168.1.1:80 proto=UDP payload="roundtrip" +var canonicalUDPRoundtrip = []byte{ + 0x45, 0x00, 0x00, 0x25, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x5c, 0xea, 0x0a, 0x00, 0x00, 0x01, + 0xc0, 0xa8, 0x01, 0x01, 0x30, 0x39, 0x00, 0x50, 0x00, 0x11, 0xd4, 0xdc, 0x72, 0x6f, 0x75, 0x6e, + 0x64, 0x74, 0x72, 0x69, 0x70, +} + +// canonicalUDPSnatMe: src=10.0.0.1:12345 dst=192.168.1.1:80 proto=UDP payload="snat me" +var canonicalUDPSnatMe = []byte{ + 0x45, 0x00, 0x00, 0x23, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x5c, 0xec, 0x0a, 0x00, 0x00, 0x01, + 0xc0, 0xa8, 0x01, 0x01, 0x30, 0x39, 0x00, 0x50, 0x00, 0x0f, 0xa9, 0x4c, 0x73, 0x6e, 0x61, 0x74, + 0x20, 0x6d, 0x65, +} + +// Expected outputs after rewriting — built from scratch with the post-rewrite +// addresses, so all checksums are independently correct. + +// canonicalUDPSnatted: canonicalUDP with src rewritten to 169.254.55.96:55555 +var canonicalUDPSnatted = []byte{ + 0x45, 0x00, 0x00, 0x27, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x85, 0x8a, 0xa9, 0xfe, 0x37, 0x60, + 0xc0, 0xa8, 0x01, 0x01, 0xd9, 0x03, 0x00, 0x50, 0x00, 0x13, 0xf1, 0x9d, 0x68, 0x65, 0x6c, 0x6c, + 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64, +} + +// canonicalUDPReplyUnSnatted: canonicalUDPReply with dst rewritten from 169.254.55.96:55555 to 10.0.0.1:12345 +var canonicalUDPReplyUnSnatted = []byte{ + 0x45, 0x00, 0x00, 0x21, 0x12, 0x34, 0x40, 0x00, 0x40, 0x11, 0x5c, 0xee, 0xc0, 0xa8, 0x01, 0x01, + 0x0a, 0x00, 0x00, 0x01, 0x00, 0x50, 0x30, 0x39, 0x00, 0x0d, 0xa7, 0xce, 0x72, 0x65, 0x70, 0x6c, + 0x79, +} + +// canonicalTCPSnatted: canonicalTCP with src rewritten to 169.254.55.96:55555 +var canonicalTCPSnatted = []byte{ + 0x45, 0x00, 0x00, 0x36, 0x12, 0x34, 0x40, 0x00, 0x40, 0x06, 0x85, 0x86, 0xa9, 0xfe, 0x37, 0x60, + 0xc0, 0xa8, 0x01, 0x01, 0xd9, 0x03, 0x00, 0x50, 0x12, 0x34, 0x56, 0x78, 0x00, 0x00, 0x00, 0x00, + 0x50, 0x02, 0xff, 0xff, 0x06, 0x40, 0x00, 0x00, 0x47, 0x45, 0x54, 0x20, 0x2f, 0x20, 0x48, 0x54, + 0x54, 0x50, 0x2f, 0x31, 0x2e, 0x31, +} + +// canonicalICMPSnatted: canonicalICMP with src rewritten to 169.254.55.96, id changed from 0x1234 to 0x5678 +var canonicalICMPSnatted = []byte{ + 0x45, 0x00, 0x00, 0x1c, 0x12, 0x34, 0x40, 0x00, 0x40, 0x01, 0x85, 0xa5, 0xa9, 0xfe, 0x37, 0x60, + 0xc0, 0xa8, 0x01, 0x01, 0x08, 0x00, 0xa1, 0x86, 0x56, 0x78, 0x00, 0x01, +} + +func TestCalcNewTransportChecksum_Identity(t *testing.T) { + // Rewriting to the same IP/port should return the same checksum + ip := netip.MustParseAddr("10.0.0.1") + result := calcNewTransportChecksum(0x1234, ip, 80, ip, 80) + assert.Equal(t, uint16(0x1234), result) +} + +func TestCalcNewTransportChecksum_VsCanonical(t *testing.T) { + srcIP := netip.MustParseAddr("10.0.0.1") + snatIP := netip.MustParseAddr("169.254.55.96") + + // Extract the original UDP checksum from canonicalUDP (bytes 26-27) + origChecksum := binary.BigEndian.Uint16(canonicalUDP[26:28]) + + // Compute incrementally + incremental := calcNewTransportChecksum(origChecksum, srcIP, 12345, snatIP, 55555) + + // Verify it matches the checksum in the independently-computed canonicalUDPSnatted + expectedChecksum := binary.BigEndian.Uint16(canonicalUDPSnatted[26:28]) + assert.Equal(t, expectedChecksum, incremental, "incremental checksum should match canonical expected output") +} + +func TestCalcNewICMPChecksum_Identity(t *testing.T) { + // Same values in and out should be identity + result := calcNewICMPChecksum(0xABCD, 0, 0, 1234, 1234) + assert.Equal(t, uint16(0xABCD), result) +} + +func TestRewritePacket_UDP(t *testing.T) { + srcIP := netip.MustParseAddr("10.0.0.1") + dstIP := netip.MustParseAddr("192.168.1.1") + snatIP := netip.MustParseAddr("169.254.55.96") + + pkt := slices.Clone(canonicalUDP) + + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: srcIP, + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + + // SNAT rewrites source: IP at offset 12, port at offset 0 inside transport + oldIP := netip.AddrPortFrom(srcIP, 12345) + newIP := netip.AddrPortFrom(snatIP, 55555) + rewritePacket(pkt, &fp, oldIP, newIP, 12, 0) + + assert.Equal(t, canonicalUDPSnatted, pkt, "rewritten packet should match canonical expected output") +} + +func TestRewritePacket_UDP_UnSNAT(t *testing.T) { + snatIP := netip.MustParseAddr("169.254.55.96") + dstIP := netip.MustParseAddr("192.168.1.1") + origSrcIP := netip.MustParseAddr("10.0.0.1") + + pkt := slices.Clone(canonicalUDPReply) + + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: snatIP, + LocalPort: 80, + RemotePort: 55555, + Protocol: firewall.ProtoUDP, + } + + // UnSNAT rewrites destination: IP at offset 16, port at offset 2 inside transport + oldIP := netip.AddrPortFrom(snatIP, 55555) + newIP := netip.AddrPortFrom(origSrcIP, 12345) + rewritePacket(pkt, &fp, oldIP, newIP, 16, 2) + + assert.Equal(t, canonicalUDPReplyUnSnatted, pkt, "un-SNATted packet should match canonical expected output") +} + +func TestRewritePacket_TCP(t *testing.T) { + srcIP := netip.MustParseAddr("10.0.0.1") + dstIP := netip.MustParseAddr("192.168.1.1") + snatIP := netip.MustParseAddr("169.254.55.96") + + pkt := slices.Clone(canonicalTCP) + + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: srcIP, + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoTCP, + } + + oldIP := netip.AddrPortFrom(srcIP, 12345) + newIP := netip.AddrPortFrom(snatIP, 55555) + rewritePacket(pkt, &fp, oldIP, newIP, 12, 0) + + assert.Equal(t, canonicalTCPSnatted, pkt, "rewritten TCP packet should match canonical expected output") +} + +func TestRewritePacket_ICMP(t *testing.T) { + srcIP := netip.MustParseAddr("10.0.0.1") + dstIP := netip.MustParseAddr("192.168.1.1") + snatIP := netip.MustParseAddr("169.254.55.96") + + pkt := slices.Clone(canonicalICMP) + + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: srcIP, + LocalPort: 0, + RemotePort: 0x1234, // ICMP ID used as port + Protocol: firewall.ProtoICMP, + } + + oldIP := netip.AddrPortFrom(srcIP, 0x1234) + newIP := netip.AddrPortFrom(snatIP, 0x5678) + rewritePacket(pkt, &fp, oldIP, newIP, 12, 0) + + assert.Equal(t, canonicalICMPSnatted, pkt, "rewritten ICMP packet should match canonical expected output") +} + +func TestRewritePacket_Roundtrip(t *testing.T) { + // Test that SNAT followed by unSNAT produces the original packet + srcIP := netip.MustParseAddr("10.0.0.1") + dstIP := netip.MustParseAddr("192.168.1.1") + snatIP := netip.MustParseAddr("169.254.55.96") + + pkt := slices.Clone(canonicalUDPRoundtrip) + + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: srcIP, + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + + // SNAT: rewrite source + oldSrc := netip.AddrPortFrom(srcIP, 12345) + newSrc := netip.AddrPortFrom(snatIP, 55555) + rewritePacket(pkt, &fp, oldSrc, newSrc, 12, 0) + + // Verify intermediate state is not the original + require.NotEqual(t, canonicalUDPRoundtrip, pkt) + + // UnSNAT: rewrite source back + rewritePacket(pkt, &fp, newSrc, oldSrc, 12, 0) + + // Packet should be byte-for-byte identical to original + assert.Equal(t, canonicalUDPRoundtrip, pkt, "packet should be identical after roundtrip SNAT/unSNAT") +} + +func TestSnatInfo_Valid(t *testing.T) { + t.Run("nil is invalid", func(t *testing.T) { + var s *snatInfo + assert.False(t, s.Valid()) + }) + + t.Run("zero value is invalid", func(t *testing.T) { + s := &snatInfo{} + assert.False(t, s.Valid()) + }) + + t.Run("with valid src is valid", func(t *testing.T) { + s := &snatInfo{ + Src: netip.AddrPortFrom(netip.MustParseAddr("10.0.0.1"), 1234), + SrcVpnIp: netip.MustParseAddr("fd00::1"), + SnatPort: 55555, + } + assert.True(t, s.Valid()) + }) +} + +func TestFirewall_ShouldUnSNAT(t *testing.T) { + snatAddr := netip.MustParseAddr("169.254.55.96") + + t.Run("no snat addr configured", func(t *testing.T) { + fw := &Firewall{} + fp := &firewall.Packet{RemoteAddr: snatAddr} + assert.False(t, fw.ShouldUnSNAT(fp)) + }) + + t.Run("packet to snat addr", func(t *testing.T) { + fw := &Firewall{snatAddr: snatAddr} + fp := &firewall.Packet{RemoteAddr: snatAddr} + assert.True(t, fw.ShouldUnSNAT(fp)) + }) + + t.Run("packet to different addr", func(t *testing.T) { + fw := &Firewall{snatAddr: snatAddr} + fp := &firewall.Packet{RemoteAddr: netip.MustParseAddr("10.0.0.1")} + assert.False(t, fw.ShouldUnSNAT(fp)) + }) +} + +func TestFirewall_IdentifyNetworkType_SNATPeer(t *testing.T) { + snatAddr := netip.MustParseAddr("169.254.55.96") + + t.Run("v4 packet from v6-only host without networks table", func(t *testing.T) { + fw := &Firewall{snatAddr: snatAddr} + h := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("fd00::1")}} + fp := firewall.Packet{ + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalAddr: netip.MustParseAddr("192.168.1.1"), + } + assert.Equal(t, NetworkTypeUncheckedSNATPeer, fw.identifyNetworkType(h, fp)) + }) + + t.Run("v4 packet from v4 host is not snat peer", func(t *testing.T) { + fw := &Firewall{snatAddr: snatAddr} + h := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("10.0.0.1")}} + fp := firewall.Packet{ + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalAddr: netip.MustParseAddr("192.168.1.1"), + } + assert.Equal(t, NetworkTypeVPN, fw.identifyNetworkType(h, fp)) + }) + + t.Run("v6 packet from v6 host is VPN", func(t *testing.T) { + fw := &Firewall{snatAddr: snatAddr} + h := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("fd00::1")}} + fp := firewall.Packet{ + RemoteAddr: netip.MustParseAddr("fd00::1"), + LocalAddr: netip.MustParseAddr("fd00::2"), + } + assert.Equal(t, NetworkTypeVPN, fw.identifyNetworkType(h, fp)) + }) + + t.Run("mismatched v4 from v4 host is invalid", func(t *testing.T) { + fw := &Firewall{snatAddr: snatAddr} + h := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("10.0.0.1")}} + fp := firewall.Packet{ + RemoteAddr: netip.MustParseAddr("10.0.0.99"), + LocalAddr: netip.MustParseAddr("192.168.1.1"), + } + assert.Equal(t, NetworkTypeInvalidPeer, fw.identifyNetworkType(h, fp)) + }) +} + +func TestFirewall_AllowNetworkType_SNAT(t *testing.T) { + t.Run("snat peer allowed with snat addr", func(t *testing.T) { + fw := &Firewall{snatAddr: netip.MustParseAddr("169.254.55.96")} + assert.NoError(t, fw.allowNetworkType(NetworkTypeUncheckedSNATPeer)) + }) + + t.Run("snat peer rejected without snat addr", func(t *testing.T) { + fw := &Firewall{} + assert.ErrorIs(t, fw.allowNetworkType(NetworkTypeUncheckedSNATPeer), ErrInvalidRemoteIP) + }) + + t.Run("vpn always allowed", func(t *testing.T) { + fw := &Firewall{} + assert.NoError(t, fw.allowNetworkType(NetworkTypeVPN)) + }) + + t.Run("unsafe always allowed", func(t *testing.T) { + fw := &Firewall{} + assert.NoError(t, fw.allowNetworkType(NetworkTypeUnsafe)) + }) + + t.Run("invalid peer rejected", func(t *testing.T) { + fw := &Firewall{} + assert.ErrorIs(t, fw.allowNetworkType(NetworkTypeInvalidPeer), ErrInvalidRemoteIP) + }) + + t.Run("vpn peer rejected", func(t *testing.T) { + fw := &Firewall{} + assert.ErrorIs(t, fw.allowNetworkType(NetworkTypeVPNPeer), ErrPeerRejected) + }) +} + +func TestFirewall_FindUsableSNATPort(t *testing.T) { + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + snatAddr := netip.MustParseAddr("169.254.55.96") + + t.Run("finds first available port", func(t *testing.T) { + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + + fp := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: snatAddr, + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cn := &conn{} + err := fw.findUsableSNATPort(&fp, cn) + require.NoError(t, err) + // Port should have been assigned + assert.Equal(t, uint16(12345), fp.RemotePort, "should use original port if available") + }) + + t.Run("skips occupied port", func(t *testing.T) { + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + + fp := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: snatAddr, + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + // Occupy the port + fw.Conntrack.Lock() + fw.Conntrack.Conns[fp] = &conn{} + fw.Conntrack.Unlock() + + cn := &conn{} + err := fw.findUsableSNATPort(&fp, cn) + require.NoError(t, err) + assert.NotEqual(t, uint16(12345), fp.RemotePort, "should pick a different port") + }) + + t.Run("returns error on exhaustion", func(t *testing.T) { + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + + // Fill all 0x7ff ports + baseFP := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: snatAddr, + LocalPort: 80, + Protocol: firewall.ProtoUDP, + } + fw.Conntrack.Lock() + for i := 0; i < 0x7ff; i++ { + fp := baseFP + fp.RemotePort = uint16(0x7ff + i) + fw.Conntrack.Conns[fp] = &conn{} + } + fw.Conntrack.Unlock() + + // Try to find a port starting from 0x7ff + fp := baseFP + fp.RemotePort = 0x7ff + cn := &conn{} + err := fw.findUsableSNATPort(&fp, cn) + assert.ErrorIs(t, err, ErrCannotSNAT) + }) +} + +func TestFirewall_ApplySnat(t *testing.T) { + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + snatAddr := netip.MustParseAddr("169.254.55.96") + peerV6Addr := netip.MustParseAddr("fd00::1") + dstIP := netip.MustParseAddr("192.168.1.1") + + t.Run("new flow from v6 host", func(t *testing.T) { + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw.snatAddr = snatAddr + + pkt := slices.Clone(canonicalUDPTest) + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cn := &conn{} + h := &HostInfo{vpnAddrs: []netip.Addr{peerV6Addr}} + + err := fw.applySnat(pkt, &fp, cn, h) + require.NoError(t, err) + + // Should have created snat info + require.True(t, cn.snat.Valid()) + assert.Equal(t, peerV6Addr, cn.snat.SrcVpnIp) + assert.Equal(t, netip.MustParseAddr("10.0.0.1"), cn.snat.Src.Addr()) + assert.Equal(t, uint16(12345), cn.snat.Src.Port()) + + // Packet source should be rewritten to snatAddr + gotSrcIP, _ := netip.AddrFromSlice(pkt[12:16]) + assert.Equal(t, snatAddr, gotSrcIP) + }) + + t.Run("existing flow with matching identity", func(t *testing.T) { + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw.snatAddr = snatAddr + + pkt := slices.Clone(canonicalUDPTest) + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cn := &conn{ + snat: &snatInfo{ + Src: netip.AddrPortFrom(netip.MustParseAddr("10.0.0.1"), 12345), + SrcVpnIp: peerV6Addr, + SnatPort: 55555, + }, + } + h := &HostInfo{vpnAddrs: []netip.Addr{peerV6Addr}} + + err := fw.applySnat(pkt, &fp, cn, h) + require.NoError(t, err) + + // Source should be rewritten + gotSrcIP, _ := netip.AddrFromSlice(pkt[12:16]) + assert.Equal(t, snatAddr, gotSrcIP) + }) + + t.Run("identity mismatch rejected", func(t *testing.T) { + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw.snatAddr = snatAddr + + pkt := slices.Clone(canonicalUDPTest) + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cn := &conn{ + snat: &snatInfo{ + Src: netip.AddrPortFrom(netip.MustParseAddr("10.0.0.1"), 12345), + SrcVpnIp: netip.MustParseAddr("fd00::99"), // Different VPN IP + SnatPort: 55555, + }, + } + // Attacker has a different VPN address + h := &HostInfo{vpnAddrs: []netip.Addr{peerV6Addr}} + + err := fw.applySnat(pkt, &fp, cn, h) + assert.ErrorIs(t, err, ErrSNATIdentityMismatch) + }) + + t.Run("no snat addr configured", func(t *testing.T) { + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + + pkt := slices.Clone(canonicalUDPTest) + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cn := &conn{} + h := &HostInfo{vpnAddrs: []netip.Addr{peerV6Addr}} + + err := fw.applySnat(pkt, &fp, cn, h) + assert.ErrorIs(t, err, ErrCannotSNAT) + }) + + t.Run("v4 host rejected for new flow", func(t *testing.T) { + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw.snatAddr = snatAddr + + pkt := slices.Clone(canonicalUDPTest) + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cn := &conn{} + // This host has a v4 address - can't SNAT for it + h := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("10.0.0.50")}} + + err := fw.applySnat(pkt, &fp, cn, h) + assert.ErrorIs(t, err, ErrCannotSNAT) + }) +} + +func TestFirewall_UnSnat(t *testing.T) { + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + snatAddr := netip.MustParseAddr("169.254.55.96") + peerV6Addr := netip.MustParseAddr("fd00::1") + origSrcIP := netip.MustParseAddr("10.0.0.1") + + t.Run("successful unsnat", func(t *testing.T) { + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw.snatAddr = snatAddr + + // Create a conntrack entry for the snatted flow + snatFP := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: snatAddr, + LocalPort: 80, + RemotePort: 55555, + Protocol: firewall.ProtoUDP, + } + fw.Conntrack.Lock() + fw.Conntrack.Conns[snatFP] = &conn{ + snat: &snatInfo{ + Src: netip.AddrPortFrom(origSrcIP, 12345), + SrcVpnIp: peerV6Addr, + SnatPort: 55555, + }, + } + fw.Conntrack.Unlock() + + pkt := slices.Clone(canonicalUDPReply) + + fp := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: snatAddr, + LocalPort: 80, + RemotePort: 55555, + Protocol: firewall.ProtoUDP, + } + + result := fw.unSnat(pkt, &fp) + assert.True(t, result.IsValid()) + assert.Equal(t, peerV6Addr, result) + + // Destination should be rewritten to the original source + gotDstIP, _ := netip.AddrFromSlice(pkt[16:20]) + assert.Equal(t, origSrcIP, gotDstIP) + }) + + t.Run("no conntrack entry", func(t *testing.T) { + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw.snatAddr = snatAddr + + pkt := slices.Clone(canonicalUDPReply) + fp := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: snatAddr, + LocalPort: 80, + RemotePort: 55555, + Protocol: firewall.ProtoUDP, + } + + result := fw.unSnat(pkt, &fp) + assert.False(t, result.IsValid()) + }) +} + +func TestFirewall_Drop_SNATFullFlow(t *testing.T) { + // Integration test: a complete SNAT flow through Drop + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + snatAddr := netip.MustParseAddr("169.254.55.96") + myV6Prefix := netip.MustParsePrefix("fd00::1/128") + unsafeNet := netip.MustParsePrefix("192.168.0.0/16") + + myCert := &dummyCert{ + name: "me", + networks: []netip.Prefix{myV6Prefix}, + unsafeNetworks: []netip.Prefix{unsafeNet}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert, snatAddr) + fw.snatAddr = snatAddr + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "any", "", "")) + + // Set up the peer: an IPv6-only host sending IPv4 traffic + peerV6Addr := netip.MustParseAddr("fd00::2") + peerCert := &dummyCert{ + name: "peer", + networks: []netip.Prefix{netip.MustParsePrefix("fd00::2/128")}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(myV6Prefix) + + h := &HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &cert.CachedCertificate{ + Certificate: peerCert, + InvertedGroups: map[string]struct{}{"default-group": {}}, + }, + }, + vpnAddrs: []netip.Addr{peerV6Addr}, + } + h.buildNetworks(myVpnNetworksTable, peerCert) + + pkt := slices.Clone(canonicalUDPSnatMe) + + fp := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cp := cert.NewCAPool() + + // Drop should succeed and SNAT the packet + err := fw.Drop(fp, pkt, true, h, cp, nil) + require.NoError(t, err) + + // After Drop, the source should be rewritten to the snat addr + gotSrcIP, _ := netip.AddrFromSlice(pkt[12:16]) + assert.Equal(t, snatAddr, gotSrcIP) +} + +func TestHasOnlyV6Addresses(t *testing.T) { + t.Run("v6 only", func(t *testing.T) { + h := &HostInfo{vpnAddrs: []netip.Addr{ + netip.MustParseAddr("fd00::1"), + netip.MustParseAddr("fd00::2"), + }} + assert.True(t, h.HasOnlyV6Addresses()) + }) + + t.Run("v4 only", func(t *testing.T) { + h := &HostInfo{vpnAddrs: []netip.Addr{ + netip.MustParseAddr("10.0.0.1"), + }} + assert.False(t, h.HasOnlyV6Addresses()) + }) + + t.Run("mixed v4 and v6", func(t *testing.T) { + h := &HostInfo{vpnAddrs: []netip.Addr{ + netip.MustParseAddr("fd00::1"), + netip.MustParseAddr("10.0.0.1"), + }} + assert.False(t, h.HasOnlyV6Addresses()) + }) +} + +// --- Adversarial SNAT Tests --- + +func TestFirewall_ApplySnat_CrossHostHijack(t *testing.T) { + // Host A (fd00::1) establishes SNAT flow. Host B (fd00::2) sends a packet + // matching the same conntrack key but with a different identity. + // applySnat must reject with ErrSNATIdentityMismatch and leave the packet unmodified. + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + snatAddr := netip.MustParseAddr("169.254.55.96") + hostA := netip.MustParseAddr("fd00::1") + hostB := netip.MustParseAddr("fd00::2") + + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw.snatAddr = snatAddr + + // Simulate Host A having established a flow + cn := &conn{ + snat: &snatInfo{ + Src: netip.AddrPortFrom(netip.MustParseAddr("10.0.0.1"), 12345), + SrcVpnIp: hostA, + SnatPort: 55555, + }, + } + + // Host B tries to reuse the same conntrack entry + pkt := slices.Clone(canonicalUDPHijack) + + fp := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + hB := &HostInfo{vpnAddrs: []netip.Addr{hostB}} + + err := fw.applySnat(pkt, &fp, cn, hB) + assert.ErrorIs(t, err, ErrSNATIdentityMismatch) + assert.Equal(t, canonicalUDPHijack, pkt, "packet bytes must be unmodified after identity mismatch") +} + +func TestFirewall_ApplySnat_MixedStackRejected(t *testing.T) { + // A host with both v4 and v6 VPN addresses should never get SNAT treatment. + // Test both orderings of vpnAddrs to verify behavior. + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + snatAddr := netip.MustParseAddr("169.254.55.96") + dstIP := netip.MustParseAddr("192.168.1.1") + + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + + t.Run("v6 first then v4", func(t *testing.T) { + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw.snatAddr = snatAddr + + pkt := slices.Clone(canonicalUDPTest) + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cn := &conn{} + // Mixed-stack: v6 first. applySnat checks vpnAddrs[0].Is6() which is true, + // so it would create a flow. But the caller (Drop) guards with HasOnlyV6Addresses(). + // This test documents that applySnat alone doesn't prevent mixed-stack SNAT. + h := &HostInfo{vpnAddrs: []netip.Addr{ + netip.MustParseAddr("fd00::1"), + netip.MustParseAddr("10.0.0.50"), + }} + + err := fw.applySnat(pkt, &fp, cn, h) + // applySnat only checks vpnAddrs[0].Is6(), so this succeeds. + // The real guard is in Drop() via HasOnlyV6Addresses(). + assert.NoError(t, err, "applySnat alone allows v6-first mixed-stack (guarded by Drop)") + }) + + t.Run("v4 first then v6", func(t *testing.T) { + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw.snatAddr = snatAddr + + pkt := slices.Clone(canonicalUDPTest) + + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cn := &conn{} + // Mixed-stack: v4 first. vpnAddrs[0].Is6() is false -> ErrCannotSNAT. + h := &HostInfo{vpnAddrs: []netip.Addr{ + netip.MustParseAddr("10.0.0.50"), + netip.MustParseAddr("fd00::1"), + }} + + err := fw.applySnat(pkt, &fp, cn, h) + assert.ErrorIs(t, err, ErrCannotSNAT) + assert.Equal(t, canonicalUDPTest, pkt, "packet bytes must be unmodified on error") + }) +} + +func TestFirewall_ApplySnat_PacketUnmodifiedOnError(t *testing.T) { + // When applySnat returns an error, the packet must not be partially rewritten. + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + dstIP := netip.MustParseAddr("192.168.1.1") + + t.Run("no snatAddr configured", func(t *testing.T) { + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + + pkt := slices.Clone(canonicalUDPTest) + + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cn := &conn{} + h := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("fd00::1")}} + + err := fw.applySnat(pkt, &fp, cn, h) + require.Error(t, err) + assert.Equal(t, canonicalUDPTest, pkt, "packet must be byte-for-byte identical after error") + }) + + t.Run("identity mismatch", func(t *testing.T) { + snatAddr := netip.MustParseAddr("169.254.55.96") + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw.snatAddr = snatAddr + + pkt := slices.Clone(canonicalUDPTest) + + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cn := &conn{ + snat: &snatInfo{ + Src: netip.AddrPortFrom(netip.MustParseAddr("10.0.0.1"), 12345), + SrcVpnIp: netip.MustParseAddr("fd00::99"), + SnatPort: 55555, + }, + } + h := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("fd00::1")}} + + err := fw.applySnat(pkt, &fp, cn, h) + require.ErrorIs(t, err, ErrSNATIdentityMismatch) + assert.Equal(t, canonicalUDPTest, pkt, "packet must be byte-for-byte identical after identity mismatch") + }) + + t.Run("v4 host rejected", func(t *testing.T) { + snatAddr := netip.MustParseAddr("169.254.55.96") + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw.snatAddr = snatAddr + + pkt := slices.Clone(canonicalUDPTest) + + fp := firewall.Packet{ + LocalAddr: dstIP, + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cn := &conn{} + h := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("10.0.0.50")}} + + err := fw.applySnat(pkt, &fp, cn, h) + require.ErrorIs(t, err, ErrCannotSNAT) + assert.Equal(t, canonicalUDPTest, pkt, "packet must be byte-for-byte identical after v4 host rejection") + }) +} + +func TestFirewall_UnSnat_NonSNATConntrack(t *testing.T) { + // A conntrack entry exists but has snat=nil. unSnat should return an invalid addr + // and not rewrite the packet. + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + snatAddr := netip.MustParseAddr("169.254.55.96") + + c := &dummyCert{ + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw.snatAddr = snatAddr + + // Create a conntrack entry with snat=nil (a normal non-SNAT connection) + snatFP := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: snatAddr, + LocalPort: 80, + RemotePort: 55555, + Protocol: firewall.ProtoUDP, + } + fw.Conntrack.Lock() + fw.Conntrack.Conns[snatFP] = &conn{ + snat: nil, // deliberately nil + } + fw.Conntrack.Unlock() + + pkt := slices.Clone(canonicalUDPReply) + + fp := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: snatAddr, + LocalPort: 80, + RemotePort: 55555, + Protocol: firewall.ProtoUDP, + } + + result := fw.unSnat(pkt, &fp) + assert.False(t, result.IsValid(), "unSnat should return invalid addr for non-SNAT conntrack entry") + assert.Equal(t, canonicalUDPReply, pkt, "packet must not be rewritten when conntrack has no snat info") +} + +func TestFirewall_Drop_FirewallBlocksSNAT(t *testing.T) { + // Firewall rules only allow port 80. An SNAT-eligible packet to port 443 + // must be rejected with ErrNoMatchingRule BEFORE any SNAT rewriting occurs. + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + snatAddr := netip.MustParseAddr("169.254.55.96") + + myCert := &dummyCert{ + name: "me", + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert, snatAddr) + fw.snatAddr = snatAddr + // Only allow port 80 inbound + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 80, 80, []string{"any"}, "", "", "any", "", "")) + + peerV6Addr := netip.MustParseAddr("fd00::2") + peerCert := &dummyCert{ + name: "peer", + networks: []netip.Prefix{netip.MustParsePrefix("fd00::2/128")}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::1/128")) + + h := &HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &cert.CachedCertificate{ + Certificate: peerCert, + InvertedGroups: map[string]struct{}{"default-group": {}}, + }, + }, + vpnAddrs: []netip.Addr{peerV6Addr}, + } + h.buildNetworks(myVpnNetworksTable, peerCert) + + // Send to port 443 (not allowed) + pkt := slices.Clone(canonicalUDPBlocked) + + fp := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 443, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cp := cert.NewCAPool() + + err := fw.Drop(fp, pkt, true, h, cp, nil) + assert.ErrorIs(t, err, ErrNoMatchingRule, "firewall should block SNAT-eligible traffic that doesn't match rules") + assert.Equal(t, canonicalUDPBlocked, pkt, "packet must not be rewritten when firewall blocks it") +} + +func TestFirewall_Drop_SNATLocalAddrNotRoutable(t *testing.T) { + // An SNAT peer sends IPv4 traffic to an address NOT in routableNetworks. + // willingToHandleLocalAddr should reject with ErrInvalidLocalIP. + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + snatAddr := netip.MustParseAddr("169.254.55.96") + + myCert := &dummyCert{ + name: "me", + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert, snatAddr) + fw.snatAddr = snatAddr + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "any", "", "")) + + peerV6Addr := netip.MustParseAddr("fd00::2") + peerCert := &dummyCert{ + name: "peer", + networks: []netip.Prefix{netip.MustParsePrefix("fd00::2/128")}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::1/128")) + + h := &HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &cert.CachedCertificate{ + Certificate: peerCert, + InvertedGroups: map[string]struct{}{"default-group": {}}, + }, + }, + vpnAddrs: []netip.Addr{peerV6Addr}, + } + h.buildNetworks(myVpnNetworksTable, peerCert) + + // Dest 172.16.0.1 is NOT in our routableNetworks (which only has fd00::1/128 and 192.168.0.0/16) + pkt := slices.Clone(canonicalUDPWrongDest) + + fp := firewall.Packet{ + LocalAddr: netip.MustParseAddr("172.16.0.1"), + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cp := cert.NewCAPool() + + err := fw.Drop(fp, pkt, true, h, cp, nil) + assert.ErrorIs(t, err, ErrInvalidLocalIP, "traffic to non-routable local address should be rejected") +} + +func TestFirewall_Drop_NoSnatAddrRejectsV6Peer(t *testing.T) { + // Firewall has no snatAddr configured. An IPv6-only peer sends IPv4 traffic. + // allowNetworkType(UncheckedSNATPeer) should reject with ErrInvalidRemoteIP. + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + myCert := &dummyCert{ + name: "me", + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert, netip.Addr{}) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "any", "", "")) + + peerV6Addr := netip.MustParseAddr("fd00::2") + peerCert := &dummyCert{ + name: "peer", + networks: []netip.Prefix{netip.MustParsePrefix("fd00::2/128")}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::1/128")) + + h := &HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &cert.CachedCertificate{ + Certificate: peerCert, + InvertedGroups: map[string]struct{}{"default-group": {}}, + }, + }, + vpnAddrs: []netip.Addr{peerV6Addr}, + } + h.buildNetworks(myVpnNetworksTable, peerCert) + + pkt := slices.Clone(canonicalUDPNoSnat) + + fp := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: netip.MustParseAddr("10.0.0.1"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cp := cert.NewCAPool() + + err := fw.Drop(fp, pkt, true, h, cp, nil) + assert.ErrorIs(t, err, ErrInvalidRemoteIP, "v6 peer with no snatAddr should be rejected") +} + +func TestFirewall_Drop_IPv4HostNotSNATted(t *testing.T) { + // An IPv4 VPN host sends IPv4 traffic. Even though the router has snatAddr + // configured and the traffic is IPv4, the firewall must NOT treat this as + // UncheckedSNATPeer. The packet must not be SNAT-rewritten. + l := logrus.New() + l.SetLevel(logrus.PanicLevel) + + snatAddr := netip.MustParseAddr("169.254.55.96") + + t.Run("v6-only router rejects v4 peer as VPNPeer", func(t *testing.T) { + // When the router is v6-only, the v4 peer's address is outside our VPN + // networks -> classified as NetworkTypeVPNPeer -> rejected (not SNATted). + myCert := &dummyCert{ + name: "me", + networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert, snatAddr) + fw.snatAddr = snatAddr + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "any", "", "")) + + peerV4Addr := netip.MustParseAddr("10.128.0.2") + peerCert := &dummyCert{ + name: "v4peer", + networks: []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::1/128")) + + h := &HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &cert.CachedCertificate{ + Certificate: peerCert, + InvertedGroups: map[string]struct{}{"default-group": {}}, + }, + }, + vpnAddrs: []netip.Addr{peerV4Addr}, + } + h.buildNetworks(myVpnNetworksTable, peerCert) + + pkt := slices.Clone(canonicalUDPV4Traffic) + + fp := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.168.1.1"), + RemoteAddr: netip.MustParseAddr("10.128.0.2"), + LocalPort: 80, + RemotePort: 12345, + Protocol: firewall.ProtoUDP, + } + cp := cert.NewCAPool() + + err := fw.Drop(fp, pkt, true, h, cp, nil) + assert.ErrorIs(t, err, ErrPeerRejected, "IPv4 peer should be rejected as VPNPeer, not treated as SNAT") + assert.Equal(t, canonicalUDPV4Traffic, pkt, "packet must not be rewritten when peer is rejected") + }) + + t.Run("identifyNetworkType classifies v4 peer correctly", func(t *testing.T) { + // Directly verify that identifyNetworkType returns the right type for + // an IPv4 peer (not UncheckedSNATPeer). + fw := &Firewall{snatAddr: snatAddr} + + // Simple case: v4 host, no networks table + h := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("10.128.0.2")}} + fp := firewall.Packet{ + RemoteAddr: netip.MustParseAddr("10.128.0.2"), + LocalAddr: netip.MustParseAddr("192.168.1.1"), + } + nwType := fw.identifyNetworkType(h, fp) + assert.Equal(t, NetworkTypeVPN, nwType, "v4 peer using its own VPN addr should be NetworkTypeVPN") + assert.NotEqual(t, NetworkTypeUncheckedSNATPeer, nwType, "must NOT be classified as SNAT peer") + }) + + t.Run("identifyNetworkType v4 peer with mismatched source", func(t *testing.T) { + // v4 host sends with a source IP that doesn't match its VPN addr + fw := &Firewall{snatAddr: snatAddr} + + h := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("10.128.0.2")}} + fp := firewall.Packet{ + RemoteAddr: netip.MustParseAddr("10.0.0.99"), // Not the peer's VPN addr + LocalAddr: netip.MustParseAddr("192.168.1.1"), + } + nwType := fw.identifyNetworkType(h, fp) + assert.Equal(t, NetworkTypeInvalidPeer, nwType, "v4 peer with mismatched source should be InvalidPeer") + assert.NotEqual(t, NetworkTypeUncheckedSNATPeer, nwType, "must NOT be classified as SNAT peer") + }) +} From 064153f0c23d35a8667a86b6621f9267388f57c7 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Thu, 19 Feb 2026 14:18:09 -0600 Subject: [PATCH 09/31] split the client-snat-addr and the router-snat-addr to decrease confusion hopefully --- firewall.go | 79 +++++++++++++++++++++++---------------- overlay/device.go | 1 + overlay/tun.go | 81 ++++++++++++++++++++++++++-------------- overlay/tun_android.go | 23 +++++++++--- overlay/tun_darwin.go | 30 ++++++++------- overlay/tun_disabled.go | 18 +++++---- overlay/tun_freebsd.go | 28 ++++++++------ overlay/tun_ios.go | 19 +++++++--- overlay/tun_linux.go | 32 +++++++++------- overlay/tun_netbsd.go | 28 ++++++++------ overlay/tun_openbsd.go | 28 ++++++++------ overlay/tun_snat_test.go | 28 +++++++++----- overlay/tun_tester.go | 22 +++++++---- overlay/tun_windows.go | 28 ++++++++------ overlay/user.go | 3 ++ snat_test.go | 49 ++++++++++++------------ test/tun.go | 4 ++ 17 files changed, 304 insertions(+), 197 deletions(-) diff --git a/firewall.go b/firewall.go index f5137946..b2d15741 100644 --- a/firewall.go +++ b/firewall.go @@ -89,6 +89,7 @@ type Firewall struct { defaultLocalCIDRAny bool incomingMetrics firewallMetrics outgoingMetrics firewallMetrics + unsafeIPv4Origin netip.Addr snatAddr netip.Addr l *logrus.Logger @@ -182,14 +183,12 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D tmax = defaultTimeout } - hasV4Networks := false routableNetworks := new(bart.Lite) var assignedNetworks []netip.Prefix for _, network := range c.Networks() { nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen()) routableNetworks.Insert(nprefix) assignedNetworks = append(assignedNetworks, network) - hasV4Networks = hasV4Networks || network.Addr().Is4() } hasUnsafeNetworks := false @@ -198,10 +197,6 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D hasUnsafeNetworks = true } - if !hasUnsafeNetworks || hasV4Networks { - snatAddr = netip.Addr{} //disable using the special snat address if it doesn't make sense to use it - } - return &Firewall{ Conntrack: &FirewallConntrack{ Conns: make(map[firewall.Packet]*conn), @@ -356,9 +351,9 @@ func (f *Firewall) GetRuleHashes() string { func (f *Firewall) SetSNATAddressFromInterface(i *Interface) { //address-mutation-avoidance is done inside Interface, the firewall doesn't need to care //todo should snatted conntracks get expired out? Probably not needed until if/when we allow reload - if f.hasUnsafeNetworks { //todo this logic??? - f.snatAddr = i.inside.SNATAddress().Addr() - } + f.snatAddr = i.inside.SNATAddress().Addr() + f.unsafeIPv4Origin = i.inside.UnsafeIPv4OriginAddress().Addr() + //f.routableNetworks.Insert(i.inside.UnsafeIPv4OriginAddress()) //todo is this the right idea? } func (f *Firewall) ShouldUnSNAT(fp *firewall.Packet) bool { @@ -560,27 +555,26 @@ func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo return nil } -func (f *Firewall) identifyNetworkType(h *HostInfo, fp firewall.Packet) NetworkType { +func (f *Firewall) identifyRemoteNetworkType(h *HostInfo, fp firewall.Packet) NetworkType { if h.networks == nil { // Simple case: Certificate has one address and no unsafe networks if h.vpnAddrs[0] == fp.RemoteAddr { return NetworkTypeVPN - } else if fp.IsIPv4() && h.HasOnlyV6Addresses() { - return NetworkTypeUncheckedSNATPeer - } else { - return NetworkTypeInvalidPeer - } + } //else, fallthrough } else if nwType, ok := h.networks.Lookup(fp.RemoteAddr); ok { //todo check for if fp.RemoteAddr is our f.snatAddr here too? Does that need a special case? return nwType //will return NetworkTypeVPN or NetworkTypeUnsafe - } else if fp.IsIPv4() && h.HasOnlyV6Addresses() { //todo surely I'm smart enough to avoid writing these branches twice + } + + //RemoteAddr not in our networks table + if f.snatAddr.IsValid() && fp.IsIPv4() && h.HasOnlyV6Addresses() { return NetworkTypeUncheckedSNATPeer } else { return NetworkTypeInvalidPeer } } -func (f *Firewall) allowNetworkType(nwType NetworkType) error { +func (f *Firewall) allowRemoteNetworkType(nwType NetworkType, fp firewall.Packet) error { switch nwType { case NetworkTypeVPN: return nil @@ -592,7 +586,10 @@ func (f *Firewall) allowNetworkType(nwType NetworkType) error { case NetworkTypeUnsafe: return nil // nothing special, one day this may have different FW rules case NetworkTypeUncheckedSNATPeer: - if f.snatAddr.IsValid() { + if f.unsafeIPv4Origin.IsValid() && fp.LocalAddr == f.unsafeIPv4Origin { + return nil //the client case + } + if f.snatAddr.IsValid() { //todo return nil //todo is this enough? } else { return ErrInvalidRemoteIP @@ -606,21 +603,37 @@ func (f *Firewall) willingToHandleLocalAddr(incoming bool, fp firewall.Packet, r if f.routableNetworks.Contains(fp.LocalAddr) { return nil //easy, this should handle NetworkTypeVPN in all cases, and NetworkTypeUnsafe on the router side } - - //watch out, when incoming, this function decides if we will deliver a packet locally - //when outgoing, much less important, it just decides if we're willing to tx - switch remoteNwType { - // we never want to accept unconntracked inbound traffic from these network types, but outbound is okay. - // It's the recipient's job to validate and accept or deny the packet. - case NetworkTypeUncheckedSNATPeer, NetworkTypeUnsafe: - //NetworkTypeUnsafe needed here to allow inbound from an unsafe-router - if incoming { - return ErrInvalidLocalIP - } - return nil - default: + if incoming { //at least for now, reject all traffic other than what we've already decided is routable return ErrInvalidLocalIP } + + //now, all traffic is outgoing. Outgoing traffic to these types is not required to be considered inbound-routable + //todo is this right??? can/should these rules be tighter? + if remoteNwType == NetworkTypeUnsafe { + return nil + } + //if remoteNwType == NetworkTypeUncheckedSNATPeer { + // return nil + //} + + //todo + + ////watch out, when incoming, this function decides if we will deliver a packet locally + ////when outgoing, much less important, it just decides if we're willing to tx + //switch remoteNwType { + //// we never want to accept unconntracked inbound traffic from these network types, but outbound is okay. + //// It's the recipient's job to validate and accept or deny the packet. + //case NetworkTypeUncheckedSNATPeer, NetworkTypeUnsafe: + // //NetworkTypeUnsafe needed here to allow inbound from an unsafe-router + // if incoming { + // return ErrInvalidLocalIP + // } + // return nil + //default: + // return ErrInvalidLocalIP + //} + + return ErrInvalidLocalIP } // Drop returns an error if the packet should be dropped, explaining why. It @@ -654,8 +667,8 @@ func (f *Firewall) Drop(fp firewall.Packet, pkt []byte, incoming bool, h *HostIn } // Make sure remote address matches nebula certificate, and determine how to treat it - remoteNetworkType := f.identifyNetworkType(h, fp) - if err := f.allowNetworkType(remoteNetworkType); err != nil { + remoteNetworkType := f.identifyRemoteNetworkType(h, fp) + if err := f.allowRemoteNetworkType(remoteNetworkType, fp); err != nil { f.metrics(incoming).droppedRemoteAddr.Inc(1) return err } diff --git a/overlay/device.go b/overlay/device.go index bb14a76c..0f2f44c2 100644 --- a/overlay/device.go +++ b/overlay/device.go @@ -12,6 +12,7 @@ type Device interface { Activate() error Networks() []netip.Prefix UnsafeNetworks() []netip.Prefix + UnsafeIPv4OriginAddress() netip.Prefix SNATAddress() netip.Prefix Name() string RoutesFor(netip.Addr) routing.Gateways diff --git a/overlay/tun.go b/overlay/tun.go index 8bac6502..8ca6f537 100644 --- a/overlay/tun.go +++ b/overlay/tun.go @@ -131,52 +131,75 @@ func selectGateway(dest netip.Prefix, gateways []netip.Prefix) (netip.Prefix, er return netip.Prefix{}, fmt.Errorf("no gateway found for %v in the list of vpn networks", dest) } -func prepareSnatAddr(d Device, l *logrus.Logger, c *config.C, routes []Route) netip.Prefix { +func genLinkLocal() netip.Prefix { + octets := []byte{169, 254, 0, 0} + _, _ = rand.Read(octets[2:4]) + if octets[3] == 0 { + octets[3] = 1 //please no .0 addresses + } else if octets[2] == 255 && octets[3] == 255 { + octets[3] = 254 //please no broadcast addresses + } + out, _ := netip.AddrFromSlice(octets) + return netip.PrefixFrom(out, 32) +} + +// prepareUnsafeOriginAddr provides the IPv4 address used on IPv6-only clients that need to access IPv4 unsafe routes +func prepareUnsafeOriginAddr(d Device, l *logrus.Logger, c *config.C, routes []Route) netip.Prefix { + if !d.Networks()[0].Addr().Is6() { + return netip.Prefix{} //if we have an IPv4 assignment within the overlay, we don't need an unsafe origin address + } + + needed := false + for _, route := range routes { //or if we have a route defined into an IPv4 range + if route.Cidr.Addr().Is4() { + needed = true //todo should this only apply to unsafe routes? almost certainly + break + } + } + if !needed { + return netip.Prefix{} + } + + //todo better config name for sure + if a := c.GetString("tun.unsafe_origin_address_for_4over6", ""); a != "" { + out, err := netip.ParseAddr(a) + if err != nil { + l.WithField("value", a).WithError(err).Warn("failed to parse tun.unsafe_origin_address_for_4over6, will use a random value") + } else if !out.Is4() || !out.IsLinkLocalUnicast() { + l.WithField("value", out).Warn("tun.unsafe_origin_address_for_4over6 must be an IPv4 address") + } else if out.IsValid() { + return netip.PrefixFrom(out, 32) + } + } + return genLinkLocal() +} + +// prepareSnatAddr provides the address that an IPv6-only unsafe router should use to SNAT traffic before handing it to the operating system +func prepareSnatAddr(d Device, l *logrus.Logger, c *config.C) netip.Prefix { if !d.Networks()[0].Addr().Is6() { return netip.Prefix{} //if we have an IPv4 assignment within the overlay, we don't need a snat address } - addSnatAddr := false + needed := false for _, un := range d.UnsafeNetworks() { //if we are an unsafe router for an IPv4 range if un.Addr().Is4() { - addSnatAddr = true + needed = true break } } - for _, route := range routes { //or if we have a route defined into an IPv4 range - if route.Cidr.Addr().Is4() { - addSnatAddr = true //todo should this only apply to unsafe routes? - break - } - } - if !addSnatAddr { + if !needed { return netip.Prefix{} } - var err error - out := netip.Addr{} if a := c.GetString("tun.snat_address_for_4over6", ""); a != "" { - out, err = netip.ParseAddr(a) + out, err := netip.ParseAddr(a) if err != nil { l.WithField("value", a).WithError(err).Warn("failed to parse tun.snat_address_for_4over6, will use a random value") } else if !out.Is4() || !out.IsLinkLocalUnicast() { l.WithField("value", out).Warn("tun.snat_address_for_4over6 must be an IPv4 address") + } else if out.IsValid() { + return netip.PrefixFrom(out, 32) } } - if !out.IsValid() { - octets := []byte{169, 254, 0, 0} - _, _ = rand.Read(octets[2:4]) - if octets[3] == 0 { - octets[3] = 1 //please no .0 addresses - } else if octets[2] == 255 && octets[3] == 255 { - octets[3] = 254 //please no broadcast addresses - } - ok := false - out, ok = netip.AddrFromSlice(octets) - if !ok { - l.Error("failed to produce a valid IPv4 address for tun.snat_address_for_4over6") - return netip.Prefix{} - } - } - return netip.PrefixFrom(out, 32) + return genLinkLocal() } diff --git a/overlay/tun_android.go b/overlay/tun_android.go index 3ab3f8a7..d1434890 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -19,12 +19,13 @@ import ( type tun struct { io.ReadWriteCloser - fd int - vpnNetworks []netip.Prefix - unsafeNetworks []netip.Prefix - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *logrus.Logger + fd int + vpnNetworks []netip.Prefix + unsafeNetworks []netip.Prefix + unsafeIPv4Origin netip.Prefix + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] + l *logrus.Logger } func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix) (*tun, error) { @@ -78,6 +79,8 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } + t.unsafeIPv4Origin = prepareUnsafeOriginAddr(t, t.l, c, routes) + routeTree, err := makeRouteTree(t.l, routes, false) if err != nil { return err @@ -97,6 +100,14 @@ func (t *tun) UnsafeNetworks() []netip.Prefix { return t.UnsafeNetworks() } +func (t *tun) UnsafeIPv4OriginAddress() netip.Prefix { + return t.unsafeIPv4Origin +} + +func (t *tun) SNATAddress() netip.Prefix { + return netip.Prefix{} +} + func (t *tun) Name() string { return "android" } diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 0ab331bb..1911564a 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -24,15 +24,15 @@ import ( type tun struct { io.ReadWriteCloser - Device string - vpnNetworks []netip.Prefix - unsafeNetworks []netip.Prefix - snatAddr netip.Prefix - DefaultMTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[routing.Gateways]] - linkAddr *netroute.LinkAddr - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + unsafeNetworks []netip.Prefix + unsafeIPv4Origin netip.Prefix + DefaultMTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] + linkAddr *netroute.LinkAddr + l *logrus.Logger // cache out buffer since we need to prepend 4 bytes for tun metadata out []byte @@ -216,8 +216,8 @@ func (t *tun) Activate() error { } } } - if t.snatAddr.IsValid() && t.snatAddr.Addr().Is4() { - if err = t.activate4(t.snatAddr); err != nil { + if t.unsafeIPv4Origin.IsValid() && t.unsafeIPv4Origin.Addr().Is4() { + if err = t.activate4(t.unsafeIPv4Origin); err != nil { return err } } @@ -323,7 +323,7 @@ func (t *tun) reload(c *config.C, initial bool) error { } if initial { - t.snatAddr = prepareSnatAddr(t, t.l, c, routes) + t.unsafeIPv4Origin = prepareUnsafeOriginAddr(t, t.l, c, routes) } routeTree, err := makeRouteTree(t.l, routes, false) @@ -561,8 +561,12 @@ func (t *tun) UnsafeNetworks() []netip.Prefix { return t.unsafeNetworks } +func (t *tun) UnsafeIPv4OriginAddress() netip.Prefix { + return t.unsafeIPv4Origin +} + func (t *tun) SNATAddress() netip.Prefix { - return t.snatAddr + return netip.Prefix{} } func (t *tun) Name() string { diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go index db976d10..9ade55ac 100644 --- a/overlay/tun_disabled.go +++ b/overlay/tun_disabled.go @@ -22,13 +22,6 @@ type disabledTun struct { l *logrus.Logger } -func (*disabledTun) UnsafeNetworks() []netip.Prefix { - return nil -} -func (*disabledTun) SNATAddress() netip.Prefix { - return netip.Prefix{} -} - func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun { tun := &disabledTun{ vpnNetworks: vpnNetworks, @@ -59,6 +52,17 @@ func (t *disabledTun) Networks() []netip.Prefix { return t.vpnNetworks } +func (*disabledTun) UnsafeNetworks() []netip.Prefix { + return nil +} +func (*disabledTun) SNATAddress() netip.Prefix { + return netip.Prefix{} +} + +func (*disabledTun) UnsafeIPv4OriginAddress() netip.Prefix { + return netip.Prefix{} +} + func (*disabledTun) Name() string { return "disabled" } diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 31289d55..e0f21769 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -86,16 +86,16 @@ type ifreqAlias6 struct { } type tun struct { - Device string - vpnNetworks []netip.Prefix - unsafeNetworks []netip.Prefix - snatAddr netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[routing.Gateways]] - linkAddr *netroute.LinkAddr - l *logrus.Logger - devFd int + Device string + vpnNetworks []netip.Prefix + unsafeNetworks []netip.Prefix + unsafeIPv4Origin netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] + linkAddr *netroute.LinkAddr + l *logrus.Logger + devFd int } func (t *tun) Read(to []byte) (int, error) { @@ -414,7 +414,7 @@ func (t *tun) reload(c *config.C, initial bool) error { } if initial { - t.snatAddr = prepareSnatAddr(t, t.l, c, routes) + t.unsafeIPv4Origin = prepareUnsafeOriginAddr(t, t.l, c, routes) } routeTree, err := makeRouteTree(t.l, routes, false) @@ -457,8 +457,12 @@ func (t *tun) UnsafeNetworks() []netip.Prefix { return t.unsafeNetworks } +func (t *tun) UnsafeIPv4OriginAddress() netip.Prefix { + return t.unsafeIPv4Origin +} + func (t *tun) SNATAddress() netip.Prefix { - return t.snatAddr + return netip.Prefix{} } func (t *tun) Name() string { diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index 963e49c2..50ae4546 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -22,11 +22,12 @@ import ( type tun struct { io.ReadWriteCloser - vpnNetworks []netip.Prefix - unsafeNetworks []netip.Prefix - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *logrus.Logger + vpnNetworks []netip.Prefix + unsafeNetworks []netip.Prefix + unsafeIPv4Origin netip.Prefix + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] + l *logrus.Logger } func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ []netip.Prefix, _ bool) (*tun, error) { @@ -71,6 +72,8 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } + t.unsafeIPv4Origin = prepareUnsafeOriginAddr(t, t.l, c, routes) + routeTree, err := makeRouteTree(t.l, routes, false) if err != nil { return err @@ -153,8 +156,12 @@ func (t *tun) UnsafeNetworks() []netip.Prefix { return t.unsafeNetworks } +func (t *tun) UnsafeIPv4OriginAddress() netip.Prefix { + return t.unsafeIPv4Origin +} + func (t *tun) SNATAddress() netip.Prefix { - return t.snatAddr + return netip.Prefix{} } func (t *tun) Name() string { diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 0569fcd8..19a8952b 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -47,7 +47,8 @@ type tun struct { routesFromSystem map[netip.Prefix]routing.Gateways routesFromSystemLock sync.Mutex - snatAddr netip.Prefix + snatAddr netip.Prefix + unsafeIPv4Origin netip.Prefix l *logrus.Logger } @@ -60,6 +61,10 @@ func (t *tun) UnsafeNetworks() []netip.Prefix { return t.unsafeNetworks } +func (t *tun) UnsafeIPv4OriginAddress() netip.Prefix { + return t.unsafeIPv4Origin +} + func (t *tun) SNATAddress() netip.Prefix { return t.snatAddr } @@ -183,7 +188,8 @@ func (t *tun) reload(c *config.C, initial bool) error { } if initial { - t.snatAddr = prepareSnatAddr(t, t.l, c, routes) + t.unsafeIPv4Origin = prepareUnsafeOriginAddr(t, t.l, c, routes) //todo MUST be different from t.snatAddr! + t.snatAddr = prepareSnatAddr(t, t.l, c) } routeTree, err := makeRouteTree(t.l, routes, true) @@ -329,15 +335,15 @@ func (t *tun) addIPs(link netlink.Link) error { } } - if t.snatAddr.IsValid() && len(t.unsafeNetworks) == 0 { //TODO unsafe-routers should be able to snat and be snatted + if t.unsafeIPv4Origin.IsValid() { newAddrs = append(newAddrs, &netlink.Addr{ IPNet: &net.IPNet{ - IP: t.snatAddr.Addr().AsSlice(), - Mask: net.CIDRMask(t.snatAddr.Bits(), t.snatAddr.Addr().BitLen()), + IP: t.unsafeIPv4Origin.Addr().AsSlice(), + Mask: net.CIDRMask(t.unsafeIPv4Origin.Bits(), t.unsafeIPv4Origin.Addr().BitLen()), }, - Label: t.snatAddr.Addr().Zone(), + Label: t.unsafeIPv4Origin.Addr().Zone(), }) - t.l.WithField("address", t.snatAddr).Info("Adding SNAT address") + t.l.WithField("address", t.unsafeIPv4Origin).Info("Adding origin address for IPv4 unsafe_routes") } //add all new addresses @@ -431,9 +437,9 @@ func (t *tun) Activate() error { } } //TODO snat and be snatted - if t.snatAddr.IsValid() && len(t.unsafeNetworks) == 0 { - if err = t.setDefaultRoute(t.snatAddr); err != nil { - return fmt.Errorf("failed to set default route MTU for %s: %w", t.snatAddr, err) + if t.unsafeIPv4Origin.IsValid() { + if err = t.setDefaultRoute(t.unsafeIPv4Origin); err != nil { + return fmt.Errorf("failed to set default route MTU for %s: %w", t.unsafeIPv4Origin, err) } } @@ -565,10 +571,10 @@ func (t *tun) addRoutes(logErrors bool) error { } } - if len(t.unsafeNetworks) == 0 { - return nil + if t.snatAddr.IsValid() { + return t.setSnatRoute() } - return t.setSnatRoute() + return nil } func (t *tun) removeRoutes(routes []Route) { diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index e81e466c..448bede2 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -58,16 +58,16 @@ type addrLifetime struct { } type tun struct { - Device string - vpnNetworks []netip.Prefix - unsafeNetworks []netip.Prefix - snatAddr netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *logrus.Logger - f *os.File - fd int + Device string + vpnNetworks []netip.Prefix + unsafeNetworks []netip.Prefix + unsafeIPv4Origin netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] + l *logrus.Logger + f *os.File + fd int } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) @@ -353,7 +353,7 @@ func (t *tun) reload(c *config.C, initial bool) error { } if initial { - t.snatAddr = prepareSnatAddr(t, t.l, c, routes) + t.unsafeIPv4Origin = prepareUnsafeOriginAddr(t, t.l, c, routes) } routeTree, err := makeRouteTree(t.l, routes, false) @@ -396,8 +396,12 @@ func (t *tun) UnsafeNetworks() []netip.Prefix { return t.unsafeNetworks } +func (t *tun) UnsafeIPv4OriginAddress() netip.Prefix { + return t.unsafeIPv4Origin +} + func (t *tun) SNATAddress() netip.Prefix { - return t.snatAddr + return netip.Prefix{} } func (t *tun) Name() string { diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index e88bd0f4..bab929d0 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -49,16 +49,16 @@ type ifreq struct { } type tun struct { - Device string - vpnNetworks []netip.Prefix - unsafeNetworks []netip.Prefix - snatAddr netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *logrus.Logger - f *os.File - fd int + Device string + vpnNetworks []netip.Prefix + unsafeNetworks []netip.Prefix + unsafeIPv4Origin netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] + l *logrus.Logger + f *os.File + fd int // cache out buffer since we need to prepend 4 bytes for tun metadata out []byte } @@ -274,7 +274,7 @@ func (t *tun) reload(c *config.C, initial bool) error { } if initial { - t.snatAddr = prepareSnatAddr(t, t.l, c, routes) + t.unsafeIPv4Origin = prepareUnsafeOriginAddr(t, t.l, c, routes) } routeTree, err := makeRouteTree(t.l, routes, false) @@ -317,8 +317,12 @@ func (t *tun) UnsafeNetworks() []netip.Prefix { return t.unsafeNetworks } +func (t *tun) UnsafeIPv4OriginAddress() netip.Prefix { + return t.unsafeIPv4Origin +} + func (t *tun) SNATAddress() netip.Prefix { - return t.snatAddr + return netip.Prefix{} } func (t *tun) Name() string { diff --git a/overlay/tun_snat_test.go b/overlay/tun_snat_test.go index 0040edb4..b340eb09 100644 --- a/overlay/tun_snat_test.go +++ b/overlay/tun_snat_test.go @@ -12,11 +12,12 @@ import ( "github.com/stretchr/testify/require" ) -// mockDevice is a minimal Device implementation for testing prepareSnatAddr. +// mockDevice is a minimal Device implementation for testing prepareUnsafeOriginAddr. type mockDevice struct { networks []netip.Prefix unsafeNetworks []netip.Prefix snatAddr netip.Prefix + unsafeSnatAddr netip.Prefix } func (d *mockDevice) Read([]byte) (int, error) { return 0, nil } @@ -26,6 +27,7 @@ func (d *mockDevice) Activate() error { return func (d *mockDevice) Networks() []netip.Prefix { return d.networks } func (d *mockDevice) UnsafeNetworks() []netip.Prefix { return d.unsafeNetworks } func (d *mockDevice) SNATAddress() netip.Prefix { return d.snatAddr } +func (d *mockDevice) UnsafeIPv4OriginAddress() netip.Prefix { return d.unsafeSnatAddr } func (d *mockDevice) Name() string { return "mock" } func (d *mockDevice) RoutesFor(netip.Addr) routing.Gateways { return routing.Gateways{} } func (d *mockDevice) SupportsMultiqueue() bool { return false } @@ -40,7 +42,7 @@ func TestPrepareSnatAddr_V4Primary_NoSnat(t *testing.T) { d := &mockDevice{ networks: []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}, } - result := prepareSnatAddr(d, l, c, nil) + result := prepareUnsafeOriginAddr(d, l, c, nil) assert.Equal(t, netip.Prefix{}, result, "should not assign SNAT addr when device has IPv4 primary") } @@ -53,7 +55,7 @@ func TestPrepareSnatAddr_V6Primary_NoUnsafeOrRoutes(t *testing.T) { d := &mockDevice{ networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, } - result := prepareSnatAddr(d, l, c, nil) + result := prepareUnsafeOriginAddr(d, l, c, nil) assert.Equal(t, netip.Prefix{}, result, "should not assign SNAT addr without IPv4 unsafe networks or routes") } @@ -67,14 +69,17 @@ func TestPrepareSnatAddr_V6Primary_WithV4Unsafe(t *testing.T) { networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, } - result := prepareSnatAddr(d, l, c, nil) + result := prepareSnatAddr(d, l, c) require.True(t, result.IsValid(), "should assign SNAT addr") assert.True(t, result.Addr().Is4(), "SNAT addr should be IPv4") assert.True(t, result.Addr().IsLinkLocalUnicast(), "SNAT addr should be link-local") assert.Equal(t, 32, result.Bits(), "SNAT addr should be /32") + + result = prepareUnsafeOriginAddr(d, l, c, nil) + require.False(t, result.IsValid(), "no routes = no origin addr needed") } -func TestPrepareSnatAddr_V6Primary_WithV4Route(t *testing.T) { +func TestPrepareUnsafeOriginAddr_V6Primary_WithV4Route(t *testing.T) { l := logrus.New() l.SetLevel(logrus.PanicLevel) c := config.NewC(l) @@ -86,10 +91,13 @@ func TestPrepareSnatAddr_V6Primary_WithV4Route(t *testing.T) { routes := []Route{ {Cidr: netip.MustParsePrefix("10.0.0.0/8")}, } - result := prepareSnatAddr(d, l, c, routes) + result := prepareUnsafeOriginAddr(d, l, c, routes) require.True(t, result.IsValid(), "should assign SNAT addr when IPv4 route exists") assert.True(t, result.Addr().Is4()) assert.True(t, result.Addr().IsLinkLocalUnicast()) + + result = prepareSnatAddr(d, l, c) + require.False(t, result.IsValid(), "no UnsafeNetworks = no snat addr needed") } func TestPrepareSnatAddr_V6Primary_V6UnsafeOnly(t *testing.T) { @@ -102,7 +110,7 @@ func TestPrepareSnatAddr_V6Primary_V6UnsafeOnly(t *testing.T) { networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("fd01::/64")}, } - result := prepareSnatAddr(d, l, c, nil) + result := prepareUnsafeOriginAddr(d, l, c, nil) assert.Equal(t, netip.Prefix{}, result, "should not assign SNAT addr for IPv6-only unsafe networks") } @@ -118,7 +126,7 @@ func TestPrepareSnatAddr_ManualAddress(t *testing.T) { networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, } - result := prepareSnatAddr(d, l, c, nil) + result := prepareSnatAddr(d, l, c) require.True(t, result.IsValid()) assert.Equal(t, netip.MustParseAddr("169.254.42.42"), result.Addr()) assert.Equal(t, 32, result.Bits()) @@ -136,7 +144,7 @@ func TestPrepareSnatAddr_InvalidManualAddress_Fallback(t *testing.T) { networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, } - result := prepareSnatAddr(d, l, c, nil) + result := prepareSnatAddr(d, l, c) // Should fall back to auto-assignment require.True(t, result.IsValid(), "should fall back to auto-assigned address") assert.True(t, result.Addr().Is4()) @@ -155,7 +163,7 @@ func TestPrepareSnatAddr_AutoGenerated_Range(t *testing.T) { // Generate several addresses and verify they're all in the expected range for i := 0; i < 100; i++ { - result := prepareSnatAddr(d, l, c, nil) + result := prepareSnatAddr(d, l, c) require.True(t, result.IsValid()) addr := result.Addr() octets := addr.As4() diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index cb96c195..234d9336 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -17,13 +17,14 @@ import ( ) type TestTun struct { - Device string - vpnNetworks []netip.Prefix - unsafeNetworks []netip.Prefix - snatAddr netip.Prefix - Routes []Route - routeTree *bart.Table[routing.Gateways] - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + unsafeNetworks []netip.Prefix + snatAddr netip.Prefix + unsafeIPv4Origin netip.Prefix + Routes []Route + routeTree *bart.Table[routing.Gateways] + l *logrus.Logger closed atomic.Bool rxPackets chan []byte // Packets to receive into nebula @@ -50,7 +51,8 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNet rxPackets: make(chan []byte, 10), TxPackets: make(chan []byte, 10), } - tt.snatAddr = prepareSnatAddr(tt, l, c, routes) + tt.unsafeIPv4Origin = prepareUnsafeOriginAddr(tt, l, c, routes) + tt.snatAddr = prepareSnatAddr(tt, tt.l, c) return tt, nil } @@ -149,6 +151,10 @@ func (t *TestTun) UnsafeNetworks() []netip.Prefix { return t.unsafeNetworks } +func (t *TestTun) UnsafeIPv4OriginAddress() netip.Prefix { + return t.unsafeIPv4Origin +} + func (t *TestTun) SNATAddress() netip.Prefix { return t.snatAddr } diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index 4f8bb5b9..303b61d8 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -28,14 +28,14 @@ import ( const tunGUIDLabel = "Fixed Nebula Windows GUID v1" type winTun struct { - Device string - vpnNetworks []netip.Prefix - unsafeNetworks []netip.Prefix - snatAddr netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + unsafeNetworks []netip.Prefix + unsafeIPv4Origin netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] + l *logrus.Logger tun *wintun.NativeTun } @@ -106,7 +106,7 @@ func (t *winTun) reload(c *config.C, initial bool) error { } if initial { - t.snatAddr = prepareSnatAddr(t, t.l, c, routes) + t.unsafeIPv4Origin = prepareUnsafeOriginAddr(t, t.l, c, routes) } routeTree, err := makeRouteTree(t.l, routes, false) @@ -140,8 +140,8 @@ func (t *winTun) Activate() error { luid := winipcfg.LUID(t.tun.LUID()) prefixes := t.vpnNetworks - if t.snatAddr.IsValid() { - prefixes = append(prefixes, t.snatAddr) + if t.unsafeIPv4Origin.IsValid() { + prefixes = append(prefixes, t.unsafeIPv4Origin) } err := luid.SetIPAddresses(prefixes) @@ -241,8 +241,12 @@ func (t *winTun) UnsafeNetworks() []netip.Prefix { return t.unsafeNetworks } +func (t *winTun) UnsafeIPv4OriginAddress() netip.Prefix { + return t.unsafeIPv4Origin +} + func (t *winTun) SNATAddress() netip.Prefix { - return t.snatAddr + return netip.Prefix{} } func (t *winTun) Name() string { diff --git a/overlay/user.go b/overlay/user.go index 1c01dd1c..87eee029 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -43,6 +43,9 @@ func (d *UserDevice) UnsafeNetworks() []netip.Prefix { func (d *UserDevice) SNATAddress() netip.Prefix { return netip.Prefix{} } +func (d *UserDevice) UnsafeIPv4OriginAddress() netip.Prefix { + return netip.Prefix{} +} func (d *UserDevice) Activate() error { return nil diff --git a/snat_test.go b/snat_test.go index b6e2a116..83dfc6d9 100644 --- a/snat_test.go +++ b/snat_test.go @@ -335,7 +335,7 @@ func TestFirewall_IdentifyNetworkType_SNATPeer(t *testing.T) { RemoteAddr: netip.MustParseAddr("10.0.0.1"), LocalAddr: netip.MustParseAddr("192.168.1.1"), } - assert.Equal(t, NetworkTypeUncheckedSNATPeer, fw.identifyNetworkType(h, fp)) + assert.Equal(t, NetworkTypeUncheckedSNATPeer, fw.identifyRemoteNetworkType(h, fp)) }) t.Run("v4 packet from v4 host is not snat peer", func(t *testing.T) { @@ -345,7 +345,7 @@ func TestFirewall_IdentifyNetworkType_SNATPeer(t *testing.T) { RemoteAddr: netip.MustParseAddr("10.0.0.1"), LocalAddr: netip.MustParseAddr("192.168.1.1"), } - assert.Equal(t, NetworkTypeVPN, fw.identifyNetworkType(h, fp)) + assert.Equal(t, NetworkTypeVPN, fw.identifyRemoteNetworkType(h, fp)) }) t.Run("v6 packet from v6 host is VPN", func(t *testing.T) { @@ -355,7 +355,7 @@ func TestFirewall_IdentifyNetworkType_SNATPeer(t *testing.T) { RemoteAddr: netip.MustParseAddr("fd00::1"), LocalAddr: netip.MustParseAddr("fd00::2"), } - assert.Equal(t, NetworkTypeVPN, fw.identifyNetworkType(h, fp)) + assert.Equal(t, NetworkTypeVPN, fw.identifyRemoteNetworkType(h, fp)) }) t.Run("mismatched v4 from v4 host is invalid", func(t *testing.T) { @@ -365,39 +365,40 @@ func TestFirewall_IdentifyNetworkType_SNATPeer(t *testing.T) { RemoteAddr: netip.MustParseAddr("10.0.0.99"), LocalAddr: netip.MustParseAddr("192.168.1.1"), } - assert.Equal(t, NetworkTypeInvalidPeer, fw.identifyNetworkType(h, fp)) + assert.Equal(t, NetworkTypeInvalidPeer, fw.identifyRemoteNetworkType(h, fp)) }) } func TestFirewall_AllowNetworkType_SNAT(t *testing.T) { - t.Run("snat peer allowed with snat addr", func(t *testing.T) { - fw := &Firewall{snatAddr: netip.MustParseAddr("169.254.55.96")} - assert.NoError(t, fw.allowNetworkType(NetworkTypeUncheckedSNATPeer)) - }) - - t.Run("snat peer rejected without snat addr", func(t *testing.T) { - fw := &Firewall{} - assert.ErrorIs(t, fw.allowNetworkType(NetworkTypeUncheckedSNATPeer), ErrInvalidRemoteIP) - }) + //todo fix! + //t.Run("snat peer allowed with snat addr", func(t *testing.T) { + // fw := &Firewall{snatAddr: netip.MustParseAddr("169.254.55.96")} + // assert.NoError(t, fw.allowRemoteNetworkType(NetworkTypeUncheckedSNATPeer, fp)) + //}) + // + //t.Run("snat peer rejected without snat addr", func(t *testing.T) { + // fw := &Firewall{} + // assert.ErrorIs(t, fw.allowRemoteNetworkType(NetworkTypeUncheckedSNATPeer, fp), ErrInvalidRemoteIP) + //}) t.Run("vpn always allowed", func(t *testing.T) { fw := &Firewall{} - assert.NoError(t, fw.allowNetworkType(NetworkTypeVPN)) + assert.NoError(t, fw.allowRemoteNetworkType(NetworkTypeVPN, firewall.Packet{})) }) t.Run("unsafe always allowed", func(t *testing.T) { fw := &Firewall{} - assert.NoError(t, fw.allowNetworkType(NetworkTypeUnsafe)) + assert.NoError(t, fw.allowRemoteNetworkType(NetworkTypeUnsafe, firewall.Packet{})) }) t.Run("invalid peer rejected", func(t *testing.T) { fw := &Firewall{} - assert.ErrorIs(t, fw.allowNetworkType(NetworkTypeInvalidPeer), ErrInvalidRemoteIP) + assert.ErrorIs(t, fw.allowRemoteNetworkType(NetworkTypeInvalidPeer, firewall.Packet{}), ErrInvalidRemoteIP) }) t.Run("vpn peer rejected", func(t *testing.T) { fw := &Firewall{} - assert.ErrorIs(t, fw.allowNetworkType(NetworkTypeVPNPeer), ErrPeerRejected) + assert.ErrorIs(t, fw.allowRemoteNetworkType(NetworkTypeVPNPeer, firewall.Packet{}), ErrPeerRejected) }) } @@ -906,7 +907,7 @@ func TestFirewall_ApplySnat_MixedStackRejected(t *testing.T) { }} err := fw.applySnat(pkt, &fp, cn, h) - assert.ErrorIs(t, err, ErrCannotSNAT) + require.Error(t, err, ErrCannotSNAT) assert.Equal(t, canonicalUDPTest, pkt, "packet bytes must be unmodified on error") }) } @@ -1164,7 +1165,7 @@ func TestFirewall_Drop_SNATLocalAddrNotRoutable(t *testing.T) { func TestFirewall_Drop_NoSnatAddrRejectsV6Peer(t *testing.T) { // Firewall has no snatAddr configured. An IPv6-only peer sends IPv4 traffic. - // allowNetworkType(UncheckedSNATPeer) should reject with ErrInvalidRemoteIP. + // allowRemoteNetworkType(UncheckedSNATPeer) should reject with ErrInvalidRemoteIP. l := logrus.New() l.SetLevel(logrus.PanicLevel) @@ -1277,8 +1278,8 @@ func TestFirewall_Drop_IPv4HostNotSNATted(t *testing.T) { assert.Equal(t, canonicalUDPV4Traffic, pkt, "packet must not be rewritten when peer is rejected") }) - t.Run("identifyNetworkType classifies v4 peer correctly", func(t *testing.T) { - // Directly verify that identifyNetworkType returns the right type for + t.Run("identifyRemoteNetworkType classifies v4 peer correctly", func(t *testing.T) { + // Directly verify that identifyRemoteNetworkType returns the right type for // an IPv4 peer (not UncheckedSNATPeer). fw := &Firewall{snatAddr: snatAddr} @@ -1288,12 +1289,12 @@ func TestFirewall_Drop_IPv4HostNotSNATted(t *testing.T) { RemoteAddr: netip.MustParseAddr("10.128.0.2"), LocalAddr: netip.MustParseAddr("192.168.1.1"), } - nwType := fw.identifyNetworkType(h, fp) + nwType := fw.identifyRemoteNetworkType(h, fp) assert.Equal(t, NetworkTypeVPN, nwType, "v4 peer using its own VPN addr should be NetworkTypeVPN") assert.NotEqual(t, NetworkTypeUncheckedSNATPeer, nwType, "must NOT be classified as SNAT peer") }) - t.Run("identifyNetworkType v4 peer with mismatched source", func(t *testing.T) { + t.Run("identifyRemoteNetworkType v4 peer with mismatched source", func(t *testing.T) { // v4 host sends with a source IP that doesn't match its VPN addr fw := &Firewall{snatAddr: snatAddr} @@ -1302,7 +1303,7 @@ func TestFirewall_Drop_IPv4HostNotSNATted(t *testing.T) { RemoteAddr: netip.MustParseAddr("10.0.0.99"), // Not the peer's VPN addr LocalAddr: netip.MustParseAddr("192.168.1.1"), } - nwType := fw.identifyNetworkType(h, fp) + nwType := fw.identifyRemoteNetworkType(h, fp) assert.Equal(t, NetworkTypeInvalidPeer, nwType, "v4 peer with mismatched source should be InvalidPeer") assert.NotEqual(t, NetworkTypeUncheckedSNATPeer, nwType, "must NOT be classified as SNAT peer") }) diff --git a/test/tun.go b/test/tun.go index 37728f6c..e967568b 100644 --- a/test/tun.go +++ b/test/tun.go @@ -18,6 +18,10 @@ func (NoopTun) SNATAddress() netip.Prefix { return netip.Prefix{} } +func (NoopTun) UnsafeIPv4OriginAddress() netip.Prefix { + return netip.Prefix{} +} + func (NoopTun) RoutesFor(addr netip.Addr) routing.Gateways { return routing.Gateways{} } From 8f1d384eb82ae077c991f3ad149819b61cc1387d Mon Sep 17 00:00:00 2001 From: JackDoan Date: Thu, 19 Feb 2026 14:55:49 -0600 Subject: [PATCH 10/31] think really hard --- firewall.go | 39 +++++++++++---------------------------- 1 file changed, 11 insertions(+), 28 deletions(-) diff --git a/firewall.go b/firewall.go index b2d15741..a5f69570 100644 --- a/firewall.go +++ b/firewall.go @@ -353,7 +353,6 @@ func (f *Firewall) SetSNATAddressFromInterface(i *Interface) { //todo should snatted conntracks get expired out? Probably not needed until if/when we allow reload f.snatAddr = i.inside.SNATAddress().Addr() f.unsafeIPv4Origin = i.inside.UnsafeIPv4OriginAddress().Addr() - //f.routableNetworks.Insert(i.inside.UnsafeIPv4OriginAddress()) //todo is this the right idea? } func (f *Firewall) ShouldUnSNAT(fp *firewall.Packet) bool { @@ -525,6 +524,9 @@ func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo if !f.snatAddr.IsValid() { return ErrCannotSNAT } + if f.snatAddr == fp.LocalAddr { //a packet that came from UDP (incoming) should never ever have our snat address on it + return ErrSNATIdentityMismatch + } if c.snat.Valid() { //old flow: make sure it came from the right place if !slices.Contains(hostinfo.vpnAddrs, c.snat.SrcVpnIp) { @@ -562,7 +564,6 @@ func (f *Firewall) identifyRemoteNetworkType(h *HostInfo, fp firewall.Packet) Ne return NetworkTypeVPN } //else, fallthrough } else if nwType, ok := h.networks.Lookup(fp.RemoteAddr); ok { - //todo check for if fp.RemoteAddr is our f.snatAddr here too? Does that need a special case? return nwType //will return NetworkTypeVPN or NetworkTypeUnsafe } @@ -581,7 +582,7 @@ func (f *Firewall) allowRemoteNetworkType(nwType NetworkType, fp firewall.Packet case NetworkTypeInvalidPeer: return ErrInvalidRemoteIP case NetworkTypeVPNPeer: - //todo we might need a specialSnatMode case in here to handle routers with v4 addresses when we don't also have a v4 address? + //one day we might need a specialSnatMode case in here to handle routers with v4 addresses when we don't also have a v4 address? return ErrPeerRejected // reject for now, one day this may have different FW rules case NetworkTypeUnsafe: return nil // nothing special, one day this may have different FW rules @@ -589,8 +590,11 @@ func (f *Firewall) allowRemoteNetworkType(nwType NetworkType, fp firewall.Packet if f.unsafeIPv4Origin.IsValid() && fp.LocalAddr == f.unsafeIPv4Origin { return nil //the client case } - if f.snatAddr.IsValid() { //todo - return nil //todo is this enough? + if f.snatAddr.IsValid() { + if fp.RemoteAddr == f.snatAddr { + return ErrInvalidRemoteIP //we should never get a packet with our SNAT addr as the destination, or "from" our SNAT addr + } + return nil } else { return ErrInvalidRemoteIP } @@ -603,35 +607,14 @@ func (f *Firewall) willingToHandleLocalAddr(incoming bool, fp firewall.Packet, r if f.routableNetworks.Contains(fp.LocalAddr) { return nil //easy, this should handle NetworkTypeVPN in all cases, and NetworkTypeUnsafe on the router side } - if incoming { //at least for now, reject all traffic other than what we've already decided is routable + if incoming { //at least for now, reject all traffic other than what we've already decided is locally routable return ErrInvalidLocalIP } - //now, all traffic is outgoing. Outgoing traffic to these types is not required to be considered inbound-routable - //todo is this right??? can/should these rules be tighter? + //below this line, all traffic is outgoing. Outgoing traffic to NetworkTypeUnsafe is not required to be considered inbound-routable if remoteNwType == NetworkTypeUnsafe { return nil } - //if remoteNwType == NetworkTypeUncheckedSNATPeer { - // return nil - //} - - //todo - - ////watch out, when incoming, this function decides if we will deliver a packet locally - ////when outgoing, much less important, it just decides if we're willing to tx - //switch remoteNwType { - //// we never want to accept unconntracked inbound traffic from these network types, but outbound is okay. - //// It's the recipient's job to validate and accept or deny the packet. - //case NetworkTypeUncheckedSNATPeer, NetworkTypeUnsafe: - // //NetworkTypeUnsafe needed here to allow inbound from an unsafe-router - // if incoming { - // return ErrInvalidLocalIP - // } - // return nil - //default: - // return ErrInvalidLocalIP - //} return ErrInvalidLocalIP } From dd786cddf1d417b7046182086bc7f00baa895e42 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Thu, 19 Feb 2026 14:59:32 -0600 Subject: [PATCH 11/31] appease CI --- snat_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/snat_test.go b/snat_test.go index 83dfc6d9..40a54f6b 100644 --- a/snat_test.go +++ b/snat_test.go @@ -841,7 +841,7 @@ func TestFirewall_ApplySnat_CrossHostHijack(t *testing.T) { hB := &HostInfo{vpnAddrs: []netip.Addr{hostB}} err := fw.applySnat(pkt, &fp, cn, hB) - assert.ErrorIs(t, err, ErrSNATIdentityMismatch) + require.ErrorIs(t, err, ErrSNATIdentityMismatch) assert.Equal(t, canonicalUDPHijack, pkt, "packet bytes must be unmodified after identity mismatch") } @@ -907,7 +907,7 @@ func TestFirewall_ApplySnat_MixedStackRejected(t *testing.T) { }} err := fw.applySnat(pkt, &fp, cn, h) - require.Error(t, err, ErrCannotSNAT) + require.ErrorIs(t, err, ErrCannotSNAT) assert.Equal(t, canonicalUDPTest, pkt, "packet bytes must be unmodified on error") }) } @@ -1101,7 +1101,7 @@ func TestFirewall_Drop_FirewallBlocksSNAT(t *testing.T) { cp := cert.NewCAPool() err := fw.Drop(fp, pkt, true, h, cp, nil) - assert.ErrorIs(t, err, ErrNoMatchingRule, "firewall should block SNAT-eligible traffic that doesn't match rules") + require.ErrorIs(t, err, ErrNoMatchingRule, "firewall should block SNAT-eligible traffic that doesn't match rules") assert.Equal(t, canonicalUDPBlocked, pkt, "packet must not be rewritten when firewall blocks it") } @@ -1274,7 +1274,7 @@ func TestFirewall_Drop_IPv4HostNotSNATted(t *testing.T) { cp := cert.NewCAPool() err := fw.Drop(fp, pkt, true, h, cp, nil) - assert.ErrorIs(t, err, ErrPeerRejected, "IPv4 peer should be rejected as VPNPeer, not treated as SNAT") + require.Error(t, err, ErrPeerRejected, "IPv4 peer should be rejected as VPNPeer, not treated as SNAT") assert.Equal(t, canonicalUDPV4Traffic, pkt, "packet must not be rewritten when peer is rejected") }) From 879b77d07636ef10bb570effce022552d7c6cca3 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Thu, 19 Feb 2026 15:57:02 -0600 Subject: [PATCH 12/31] oops --- firewall.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/firewall.go b/firewall.go index a5f69570..685c0cee 100644 --- a/firewall.go +++ b/firewall.go @@ -510,8 +510,8 @@ func (f *Firewall) findUsableSNATPort(fp *firewall.Packet, c *conn) error { } //increment and retry. There's probably better strategies out there fp.RemotePort++ - if fp.RemotePort < 0x7ff { - fp.RemotePort += 0x7ff // keep it ephemeral for now + if fp.RemotePort < 0x7fff { + fp.RemotePort += 0x7fff // keep it ephemeral for now } } From ae1b501468b5a2029e2b853ed396dbdfa0ea927d Mon Sep 17 00:00:00 2001 From: JackDoan Date: Fri, 20 Feb 2026 11:29:46 -0600 Subject: [PATCH 13/31] oops --- cert/cert_v2.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/cert/cert_v2.go b/cert/cert_v2.go index 87d1ec11..09f7bd79 100644 --- a/cert/cert_v2.go +++ b/cert/cert_v2.go @@ -439,11 +439,7 @@ func (c *certificateV2) validate() error { if !hasV6Networks { return NewErrInvalidCertificateProperties("IPv6 unsafe networks require an IPv6 address assignment: %s", network) } - } else if network.Addr().Is4() { - if !hasV4Networks { - //return NewErrInvalidCertificateProperties("IPv4 unsafe networks require an IPv4 address assignment: %s", network) - } - } + } // as long as we have any IP address, IPv4 UnsafeNetworks are allowed } } From 2319eb94921dc67c9755842085ddc6a17964f314 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Fri, 20 Feb 2026 11:57:43 -0600 Subject: [PATCH 14/31] remove notes --- SNAT.md | 102 -------------------------------------------------------- 1 file changed, 102 deletions(-) delete mode 100644 SNAT.md diff --git a/SNAT.md b/SNAT.md deleted file mode 100644 index 230e44a1..00000000 --- a/SNAT.md +++ /dev/null @@ -1,102 +0,0 @@ -# Don't merge me - -# Accessing IPv4-only UnsafeNetworks via an IPv6-only overlay - -## Background -Nebula is an VPN-like connectivity solution. It provides what we call an "overlay network", -an additional IP-addressed network on top of one-or-more traditional "underlay networks". -As long as two devices with Nebula certificates (credentials, essentially signed keypairs with metadata) can find each other -and exchange traffic via a common underlay network (often this is the Internet), -they will also be able to exchange traffic securely via a point-to-point, encrypted, authenticated tunnel. - -Typically, all Nebula traffic is strongly associated with the Nebula certificate of the sender -(that is, the source IP of all packets matches the IP listed in the sender's certificate). -However, it is useful to be able to bend this rule. That is why there is another field in the Nebula certificate, named UnsafeNetworks, -which lists the network prefixes that the host bearing this certificate is allowed to "speak for". - -## Problem Statement -We want IPv6-only overlay networks to be able to carry IPv4 traffic to reach off-overlay hosts via UnsafeNetworks - -### Scenario - -To illustrate this scenario, we will define 3 hosts: -* a Phone, running Nebula, assigned the overlay IP fd01::AA/64. It has an undefined underlay, but we assert that it always has working IPv4 OR IPv6 connectivity to Router. -* a Router, running Nebula, assigned the overlay IP fd01::55/64. It has a stable underlay link that Phone can always reach. -* a Printer, which cannot run Nebula, and is only capable of IPv4 communication. It has a direct link to Router, but the Phone cannot reach it directly. - -You, the User, wish to use your Phone to print out something on the Printer while you're away from home. How can we make this possible with your IPv6-only Nebula overlay? -In particular, your Phone may connect to any cellular or public WiFi network, and we cannot control the IP address it will be assigned. If you MUST print, an IP conflict is not acceptable. -Therefore, we cannot simply dismiss this problem by suggesting that you assign a small IPv4 network within your overlay. Sure, it probably works, and in this toy scenario, the odds of a conflict are pretty small. But it scales very poorly. What if a whole company needs to use this printer (or perhaps a less contrived need?) -We can do better. - -## Solution - -* Even though Phone and Router lack IPv4 assignments, we can still put V4 addresses on their tun devices. -* Each overlay host who wishes to use this feature shall select (or configure?) an assignment within 169.254.0.0/16, the IPv4 link-local range - * this is a pretty small space, but it confines the region of IP conflict to a much smaller domain. And, because overlay hosts will never dial one another with this address, cryptographic verification of it via the certificate is less important. - * On Phone, Nebula will configure an unsafe_route to the Printer using this address. Because it is a device route, we do not need to tell the operating system the address of the next hop (no `via`) - * On Router, Nebula will use this address to masquerade packets from Phone. You'll see! -* Let's walk through setting up a TCP session between Phone and Printer in this scheme: - * Phone sends SYN to the printer's underlay IPv4 address - * This packet lands on Phone's Nebula tun device - * Nebula locates Router as the destination for this packet, as defined in `tun.unsafe_routes` - * Nebula checks the packet against the outbound chain: - * the destination IP of Printer is listed in Router's UnsafeNetworks, so that check will pass - * Phone's source IP is not listed in any certificate, but because the destination address is of `NetworkTypeUnsafe` and this is an outgoing flow, we keep going - * Actual outbound firewall rules get checked, assume they pass - * conntrack entry created to permit replies - * Phone encrypts the packet and sends it to Router - * Router gets the packet from Phone, and decrypts it. It is passed to the Nebula firewall for evaluation: - * `firewall.Drop()` on the Router's Nebula inbound rules: - * Because Router is configured to allow SNAT, and this packet is an IPv4 packet from a IPv6-only host, the firewall module enters "snat mode" (`TODO explain?`) - * This is a new flow, so the conntrack lookup for it fails - * `firewall.identifyNetworkType()` - * identify what "kind" of remote address (this is the inbound firewall, so the remote address is the packet's src IP) we've been given - * `NetworkTypeVPN`, for example is a remote address that matches the peer's certificate - * In this case, because the traffic is IPv4 traffic flowing from an IPv6-only host, and we've opted into supporting SNAT, this traffic is marked as `NetworkTypeUncheckedSNATPeer` - * `firewall.allowNetworkType()` will allow `NetworkTypeUncheckedSNATPeer` traffic to proceed because we have opted into SNAT - * `firewall.willingToHandleLocalAddr()` now needs to check if we're willing to accept the destination address - * Because this traffic is addressed to a destination listed in our UnsafeNetworks, it's considered "routable" and passes this check - * Nebula's firewall rules are evaluated as normal. In particular, the `cidr` parameter will be checked against the IPv4 address, NOT the IPv6 address in the Phone's certificate - * @Nate I think this is "correct", but could be a source of footgun - * Let's assume the Nebula rules accept the traffic - * We create a conntrack entry for this flow - * We do not want to transmit with the IPv4 source IP we got from Phone. We don't want the Phone's IP assignments (in this scheme) to enter the network-space on Router at all. - * To this end, we rewrite the source address (and port, if needed) to our own pre-selected IPv4 link-local address. This address will never actually leave the machine, but we need it so return traffic can be routed back to the nebula tun on Router - * Replace source IP with "Router's SNAT address" - * Look in our conntrack table, and ensure we do not already have a flow that matches this srcip/srcport/dstip/dstport/proto tuple - * if we do, increment srcport until we find an empty slot. Only use ephemeral ports. This gives 0x7ff flows per dstip/dstport/proto tuple, which ought to be plenty. - * Record the original srcip/srcport as part of the conntrack data for later - * Fix checksums - * Nebula writes the rewritten packet to Router's tun - * netfilter picks up the packet. In this example, Router is using `iptables`. A rule in the `nat` table similar to `-A POSTROUTING -d PRINTER_UNDERLAY_IP_HERE/32 -j MASQUERADE` is hit - * This ensures that "Router's SNAT address" never actually leaves Router. - * The packet leaves Router, and hits Printer - * Printer gleefully accepts the SYN from Router, and replies with an ACK - * iptables on Router de-masquerades the packet, and delivers it to the Nebula tun - * Nebula reads the packet off the tun. Because it came from the tun, and not UDP, remember that this is considered "inside" traffic and will be evaluated as "outbound" traffic by Nebula. - * Because this is inside traffic, it needs to be associated with a HostInfo before we can pass it to the firewall. - * Check that the packet is addressed to the "Router's SNAT address". If so, attempt to un-SNAT by "peeking" into conntrack - * If a Router needs to speak to _another_ Router with v4-in-v6 unsafe_routes like this, it _must_ use a distinct address from the "Router's SNAT address" - * the easy way on Linux to assure this is to set a route for the "Router SNAT address" to the Nebula tun, but not actually assign the address - * The "peek" into conntrack succeeds, and we find everything we need to rewrite the packet for transmission to Phone, as well as Phone's overlay IP, which lets us locate Phone's HostInfo - * The packet is rewritten, replacing the destination address/port to match the ones Phone expects - * checksums corrected - * Check the Nebula firewall, and see that we have a valid conntrack entry (wow!) - * we could _technically_ skip this check, but I dislike not passing all traffic we intend to accept into `firewall.Drop()`. The second conntrack locking-lookup does suck. There's room for improvement here. - * The traffic is accepted, encrypted, and sent back to Phone - * Phone gets the packet from Router, decrypts it, checks the firewall - * we have a conntrack entry for this flow, so the firewall accepts it, and delivers it to the tun - * Both sides now have a nice conntrack entry, and traffic should continue to flow uninterrupted until it expires - -This conntrack entry technically creates a risk though. Let's examine that. -The Phone will accept inbound traffic matching the conntrack spec from any Router-like host authorized to speak for that UnsafeRoute, not just Router. In theory, this is desireable, and the risk is mitigated by accepting/trusting Nebula's certificate model. -There's a good chance that if you "switch" from one Router to another, you'll lose your session on your Printer-like host. Such is life under NAT! - -Can the Router be exploited somehow? -* an attacker that shares a network with Printer would be able to spoof traffic as if they are Printer. This is the same risk as UnsafeNetworks today. -* an attacker on the overlay would have their traffic evaluated as inbound - * if they try to tx on the same source IP as Phone, SNAT will assign a different port - * if they try to send inbound traffic that matches the un-masqueraded traffic iptables would have delivered - * conntrack will accept the packet, but before we finish firewalling and return, is the applySnat step - * this will fail because the hostinfo that sent the packet does not contain the vpnip that is associated with the snat entry \ No newline at end of file From e77f49abb85aa2fb3d1d1c1d13c6b4d9af70b885 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Fri, 20 Feb 2026 11:59:51 -0600 Subject: [PATCH 15/31] fix test --- snat_test.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/snat_test.go b/snat_test.go index 40a54f6b..5c3538ad 100644 --- a/snat_test.go +++ b/snat_test.go @@ -461,7 +461,7 @@ func TestFirewall_FindUsableSNATPort(t *testing.T) { } fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) - // Fill all 0x7ff ports + // Fill all ports baseFP := firewall.Packet{ LocalAddr: netip.MustParseAddr("192.168.1.1"), RemoteAddr: snatAddr, @@ -469,16 +469,16 @@ func TestFirewall_FindUsableSNATPort(t *testing.T) { Protocol: firewall.ProtoUDP, } fw.Conntrack.Lock() - for i := 0; i < 0x7ff; i++ { + for i := 0; i < 65535; i++ { fp := baseFP - fp.RemotePort = uint16(0x7ff + i) + fp.RemotePort = uint16(i) fw.Conntrack.Conns[fp] = &conn{} } fw.Conntrack.Unlock() - // Try to find a port starting from 0x7ff + // Try to find a port starting from 0x8000 fp := baseFP - fp.RemotePort = 0x7ff + fp.RemotePort = 0x8000 cn := &conn{} err := fw.findUsableSNATPort(&fp, cn) assert.ErrorIs(t, err, ErrCannotSNAT) From a881e4fdf854ece6b8b6337756678772630249d7 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Fri, 20 Feb 2026 12:10:14 -0600 Subject: [PATCH 16/31] fix test2 --- snat_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/snat_test.go b/snat_test.go index 5c3538ad..8fb43a2b 100644 --- a/snat_test.go +++ b/snat_test.go @@ -1274,7 +1274,7 @@ func TestFirewall_Drop_IPv4HostNotSNATted(t *testing.T) { cp := cert.NewCAPool() err := fw.Drop(fp, pkt, true, h, cp, nil) - require.Error(t, err, ErrPeerRejected, "IPv4 peer should be rejected as VPNPeer, not treated as SNAT") + require.ErrorIs(t, err, ErrPeerRejected, "IPv4 peer should be rejected as VPNPeer, not treated as SNAT") assert.Equal(t, canonicalUDPV4Traffic, pkt, "packet must not be rewritten when peer is rejected") }) From 34e817742bfd0afc9c19982ee71852f0ab04e88d Mon Sep 17 00:00:00 2001 From: JackDoan Date: Thu, 26 Feb 2026 10:26:16 -0600 Subject: [PATCH 17/31] thanks clod! --- overlay/tun_android.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/overlay/tun_android.go b/overlay/tun_android.go index d1434890..c9213cc7 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -97,7 +97,7 @@ func (t *tun) Networks() []netip.Prefix { } func (t *tun) UnsafeNetworks() []netip.Prefix { - return t.UnsafeNetworks() + return t.unsafeNetworks } func (t *tun) UnsafeIPv4OriginAddress() netip.Prefix { From 009a4698a09c66edb7a2dfda0e6042fd7b34c5ae Mon Sep 17 00:00:00 2001 From: JackDoan Date: Thu, 26 Feb 2026 10:31:18 -0600 Subject: [PATCH 18/31] thanks clod! --- firewall.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/firewall.go b/firewall.go index 685c0cee..bfbc47e1 100644 --- a/firewall.go +++ b/firewall.go @@ -496,11 +496,12 @@ func rewritePacket(data []byte, fp *firewall.Packet, oldIP netip.AddrPort, newIP } func (f *Firewall) findUsableSNATPort(fp *firewall.Packet, c *conn) error { + const halfThePorts = 0x7fff oldPort := fp.RemotePort conntrack := f.Conntrack conntrack.Lock() defer conntrack.Unlock() - for numPortsChecked := 0; numPortsChecked < 0x7ff; numPortsChecked++ { + for numPortsChecked := 0; numPortsChecked < halfThePorts; numPortsChecked++ { _, ok := conntrack.Conns[*fp] if !ok { //yay, we can use this port @@ -510,8 +511,8 @@ func (f *Firewall) findUsableSNATPort(fp *firewall.Packet, c *conn) error { } //increment and retry. There's probably better strategies out there fp.RemotePort++ - if fp.RemotePort < 0x7fff { - fp.RemotePort += 0x7fff // keep it ephemeral for now + if fp.RemotePort < halfThePorts { + fp.RemotePort += halfThePorts // keep it ephemeral for now } } From f7dd3c0ce4aed0d2e391d9b5674585a99c2eb883 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Thu, 26 Feb 2026 10:46:06 -0600 Subject: [PATCH 19/31] moar test --- overlay/tun.go | 18 ++++++++++++++---- overlay/tun_test.go | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 4 deletions(-) create mode 100644 overlay/tun_test.go diff --git a/overlay/tun.go b/overlay/tun.go index 8ca6f537..d18ce123 100644 --- a/overlay/tun.go +++ b/overlay/tun.go @@ -3,6 +3,7 @@ package overlay import ( "crypto/rand" "fmt" + "io" "net" "net/netip" @@ -131,9 +132,18 @@ func selectGateway(dest netip.Prefix, gateways []netip.Prefix) (netip.Prefix, er return netip.Prefix{}, fmt.Errorf("no gateway found for %v in the list of vpn networks", dest) } -func genLinkLocal() netip.Prefix { +// genLinkLocal generates a random IPv4 link-local address. +// If randomizer is nil, it uses rand.Reader to find two random bytes +func genLinkLocal(randomizer io.Reader) netip.Prefix { + if randomizer == nil { + randomizer = rand.Reader + } octets := []byte{169, 254, 0, 0} - _, _ = rand.Read(octets[2:4]) + _, _ = randomizer.Read(octets[2:4]) + return coerceLinkLocal(octets) +} + +func coerceLinkLocal(octets []byte) netip.Prefix { if octets[3] == 0 { octets[3] = 1 //please no .0 addresses } else if octets[2] == 255 && octets[3] == 255 { @@ -171,7 +181,7 @@ func prepareUnsafeOriginAddr(d Device, l *logrus.Logger, c *config.C, routes []R return netip.PrefixFrom(out, 32) } } - return genLinkLocal() + return genLinkLocal(nil) } // prepareSnatAddr provides the address that an IPv6-only unsafe router should use to SNAT traffic before handing it to the operating system @@ -201,5 +211,5 @@ func prepareSnatAddr(d Device, l *logrus.Logger, c *config.C) netip.Prefix { return netip.PrefixFrom(out, 32) } } - return genLinkLocal() + return genLinkLocal(nil) } diff --git a/overlay/tun_test.go b/overlay/tun_test.go new file mode 100644 index 00000000..1e247aad --- /dev/null +++ b/overlay/tun_test.go @@ -0,0 +1,37 @@ +package overlay + +import ( + "bytes" + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLinkLocal(t *testing.T) { + r := bytes.NewReader([]byte{42, 99}) + result := genLinkLocal(r) + assert.Equal(t, netip.MustParsePrefix("169.254.42.99/32"), result, "genLinkLocal with a deterministic randomizer") + + result = genLinkLocal(nil) + assert.True(t, result.IsValid(), "genLinkLocal with nil randomizer should be valid") + assert.True(t, result.Addr().IsLinkLocalUnicast(), "genLinkLocal with nil randomizer should be link-local") + + result = coerceLinkLocal([]byte{169, 254, 100, 50}) + assert.Equal(t, netip.MustParsePrefix("169.254.100.50/32"), result, "coerceLinkLocal should pass through normal values") + + result = coerceLinkLocal([]byte{169, 254, 0, 0}) + assert.Equal(t, netip.MustParsePrefix("169.254.0.1/32"), result, "coerceLinkLocal should bump .0 last octet to .1") + + result = coerceLinkLocal([]byte{169, 254, 255, 255}) + assert.Equal(t, netip.MustParsePrefix("169.254.255.254/32"), result, "coerceLinkLocal should bump broadcast 255.255 to 255.254") + + result = coerceLinkLocal([]byte{169, 254, 0, 1}) + assert.Equal(t, netip.MustParsePrefix("169.254.0.1/32"), result, "coerceLinkLocal should leave .1 last octet unchanged") + + result = coerceLinkLocal([]byte{169, 254, 255, 254}) + assert.Equal(t, netip.MustParsePrefix("169.254.255.254/32"), result, "coerceLinkLocal should leave 255.254 unchanged") + + result = coerceLinkLocal([]byte{169, 254, 255, 100}) + assert.Equal(t, netip.MustParsePrefix("169.254.255.100/32"), result, "coerceLinkLocal should leave 255.100 unchanged") +} From e4897b07c9e17b140b01e4e92a6896d7bca09e48 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Thu, 26 Feb 2026 10:49:05 -0600 Subject: [PATCH 20/31] leftover cruft from merging --- firewall.go | 7 ------- 1 file changed, 7 deletions(-) diff --git a/firewall.go b/firewall.go index bfbc47e1..9eb2a670 100644 --- a/firewall.go +++ b/firewall.go @@ -908,13 +908,6 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCer var port int32 - if p.Protocol == firewall.ProtoICMP { - // port numbers are re-used for connection tracking and SNAT, - // but we don't want to actually filter on them for ICMP - // ICMP6 is omitted because we don't attempt to parse code/identifier/etc out of ICMP6 - return fp[firewall.PortAny].match(p, c, caPool) - } - if p.Fragment { port = firewall.PortFragment } else if incoming { From 629700fbb6a5279415869f3fc147c5b27cde14bc Mon Sep 17 00:00:00 2001 From: JackDoan Date: Thu, 26 Feb 2026 10:58:10 -0600 Subject: [PATCH 21/31] feedback --- firewall.go | 3 +- firewall_test.go | 92 ++++++++++++++++++++++++------------------------ snat_test.go | 44 +++++++++++------------ 3 files changed, 69 insertions(+), 70 deletions(-) diff --git a/firewall.go b/firewall.go index 9eb2a670..4d8d7b3b 100644 --- a/firewall.go +++ b/firewall.go @@ -165,7 +165,7 @@ type firewallLocalCIDR struct { // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. // The certificate provided should be the highest version loaded in memory. -func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate, snatAddr netip.Addr) *Firewall { +func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall { //TODO: error on 0 duration var tmin, tmax time.Duration @@ -241,7 +241,6 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3), c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10), certificate, - netip.Addr{}, ) fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false) diff --git a/firewall_test.go b/firewall_test.go index c42cad65..4df6eadd 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -21,7 +21,7 @@ import ( func TestNewFirewall(t *testing.T) { l := test.NewLogger() c := &dummyCert{} - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) conntrack := fw.Conntrack assert.NotNil(t, conntrack) assert.NotNil(t, conntrack.Conns) @@ -36,23 +36,23 @@ func TestNewFirewall(t *testing.T) { assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) - fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) - fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c, netip.Addr{}) + fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) - fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c, netip.Addr{}) + fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) - fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c, netip.Addr{}) + fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) - fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c, netip.Addr{}) + fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) } @@ -63,7 +63,7 @@ func TestFirewall_AddRule(t *testing.T) { l.SetOutput(ob) c := &dummyCert{} - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) assert.NotNil(t, fw.InRules) assert.NotNil(t, fw.OutRules) @@ -79,56 +79,56 @@ func TestFirewall_AddRule(t *testing.T) { assert.Empty(t, fw.InRules.TCP[1].Any.Groups) assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "")) assert.Nil(t, fw.InRules.UDP[1].Any.Any) assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1") assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", "")) //no matter what port is given for icmp, it should end up as "any" assert.Nil(t, fw.InRules.ICMP[firewall.PortAny].Any.Any) assert.Empty(t, fw.InRules.ICMP[firewall.PortAny].Any.Groups) assert.Contains(t, fw.InRules.ICMP[firewall.PortAny].Any.Hosts, "h1") - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", "")) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti) assert.True(t, ok) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6.String(), "", "", "")) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) _, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6) assert.True(t, ok) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti.String(), "", "")) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti) assert.True(t, ok) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti6.String(), "", "")) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6) assert.True(t, ok) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "ca-name", "")) assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "ca-sha")) assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", "", "", "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) anyIp, err := netip.ParsePrefix("0.0.0.0/0") require.NoError(t, err) @@ -139,7 +139,7 @@ func TestFirewall_AddRule(t *testing.T) { table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9")) assert.False(t, ok) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) anyIp6, err := netip.ParsePrefix("::/0") require.NoError(t, err) @@ -150,28 +150,28 @@ func TestFirewall_AddRule(t *testing.T) { table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1")) assert.False(t, ok) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "any", "", "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp.String(), "", "")) assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1"))) assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9"))) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp6.String(), "", "")) assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9"))) assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1"))) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", "any", "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) // Test error conditions - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", "", "", "", "")) require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", "", "", "", "")) } @@ -208,7 +208,7 @@ func TestFirewall_Drop(t *testing.T) { } h.buildNetworks(myVpnNetworksTable, &c) - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) cp := cert.NewCAPool() @@ -227,27 +227,27 @@ func TestFirewall_Drop(t *testing.T) { p.RemoteAddr = oldRemote // ensure signer doesn't get in the way of group checks - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on match - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) @@ -287,7 +287,7 @@ func TestFirewall_DropV6(t *testing.T) { } h.buildNetworks(myVpnNetworksTable, &c) - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) cp := cert.NewCAPool() @@ -306,27 +306,27 @@ func TestFirewall_DropV6(t *testing.T) { p.RemoteAddr = oldRemote // ensure signer doesn't get in the way of group checks - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on match - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) @@ -532,7 +532,7 @@ func TestFirewall_Drop2(t *testing.T) { } h1.buildNetworks(myVpnNetworksTable, c1.Certificate) - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", "", "", "", "")) cp := cert.NewCAPool() @@ -612,7 +612,7 @@ func TestFirewall_Drop3(t *testing.T) { } h3.buildNetworks(myVpnNetworksTable, c3.Certificate) - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "signer-sha")) cp := cert.NewCAPool() @@ -627,7 +627,7 @@ func TestFirewall_Drop3(t *testing.T) { assert.Equal(t, fw.Drop(p, nil, true, &h3, cp, nil), ErrNoMatchingRule) // Test a remote address match - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", "")) require.NoError(t, fw.Drop(p, nil, true, &h1, cp, nil)) } @@ -664,7 +664,7 @@ func TestFirewall_Drop3V6(t *testing.T) { h.buildNetworks(myVpnNetworksTable, c.Certificate) // Test a remote address match - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) cp := cert.NewCAPool() require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", "")) require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) @@ -704,7 +704,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { } h.buildNetworks(myVpnNetworksTable, c.Certificate) - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) cp := cert.NewCAPool() @@ -717,7 +717,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil)) oldFw := fw - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", "", "", "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 @@ -726,7 +726,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil)) oldFw = fw - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", "", "", "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 @@ -771,7 +771,7 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) { } t.Run("ICMP allowed", func(t *testing.T) { - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 0, 0, []string{"any"}, "", "", "", "", "")) t.Run("zero ports", func(t *testing.T) { p := templ.Copy() @@ -801,7 +801,7 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) { }) t.Run("Any proto, some ports allowed", func(t *testing.T) { - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 80, 444, []string{"any"}, "", "", "", "", "")) t.Run("zero ports, still blocked", func(t *testing.T) { p := templ.Copy() @@ -843,7 +843,7 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) { }) }) t.Run("Any proto, any port", func(t *testing.T) { - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) t.Run("zero ports, allowed", func(t *testing.T) { resetConntrack(fw) @@ -908,7 +908,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) { } h1.buildNetworks(myVpnNetworksTable, c1.Certificate) - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "")) cp := cert.NewCAPool() @@ -1420,7 +1420,7 @@ func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup { for _, prefix := range c.Networks() { myVpnNetworksTable.Insert(prefix) } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) return testsetup{ @@ -1572,7 +1572,7 @@ func TestFirewall_SNAT(t *testing.T) { t.Parallel() myCert := MyCert.Copy() setup := newSnatSetup(t, l, myPrefix, snatAddr) - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) resetConntrack(setup.fw) h := buildHostinfo(setup, theirPrefix) diff --git a/snat_test.go b/snat_test.go index 8fb43a2b..14d55f4a 100644 --- a/snat_test.go +++ b/snat_test.go @@ -413,7 +413,7 @@ func TestFirewall_FindUsableSNATPort(t *testing.T) { networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) fp := firewall.Packet{ LocalAddr: netip.MustParseAddr("192.168.1.1"), @@ -434,7 +434,7 @@ func TestFirewall_FindUsableSNATPort(t *testing.T) { networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) fp := firewall.Packet{ LocalAddr: netip.MustParseAddr("192.168.1.1"), @@ -459,7 +459,7 @@ func TestFirewall_FindUsableSNATPort(t *testing.T) { networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) // Fill all ports baseFP := firewall.Packet{ @@ -498,7 +498,7 @@ func TestFirewall_ApplySnat(t *testing.T) { networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw.snatAddr = snatAddr pkt := slices.Clone(canonicalUDPTest) @@ -531,7 +531,7 @@ func TestFirewall_ApplySnat(t *testing.T) { networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw.snatAddr = snatAddr pkt := slices.Clone(canonicalUDPTest) @@ -564,7 +564,7 @@ func TestFirewall_ApplySnat(t *testing.T) { networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw.snatAddr = snatAddr pkt := slices.Clone(canonicalUDPTest) @@ -593,7 +593,7 @@ func TestFirewall_ApplySnat(t *testing.T) { c := &dummyCert{ networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) pkt := slices.Clone(canonicalUDPTest) fp := firewall.Packet{ @@ -615,7 +615,7 @@ func TestFirewall_ApplySnat(t *testing.T) { networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw.snatAddr = snatAddr pkt := slices.Clone(canonicalUDPTest) @@ -648,7 +648,7 @@ func TestFirewall_UnSnat(t *testing.T) { networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw.snatAddr = snatAddr // Create a conntrack entry for the snatted flow @@ -693,7 +693,7 @@ func TestFirewall_UnSnat(t *testing.T) { networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw.snatAddr = snatAddr pkt := slices.Clone(canonicalUDPReply) @@ -727,7 +727,7 @@ func TestFirewall_Drop_SNATFullFlow(t *testing.T) { issuer: "signer-shasum", } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert) fw.snatAddr = snatAddr require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "any", "", "")) @@ -816,7 +816,7 @@ func TestFirewall_ApplySnat_CrossHostHijack(t *testing.T) { networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw.snatAddr = snatAddr // Simulate Host A having established a flow @@ -860,7 +860,7 @@ func TestFirewall_ApplySnat_MixedStackRejected(t *testing.T) { } t.Run("v6 first then v4", func(t *testing.T) { - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw.snatAddr = snatAddr pkt := slices.Clone(canonicalUDPTest) @@ -887,7 +887,7 @@ func TestFirewall_ApplySnat_MixedStackRejected(t *testing.T) { }) t.Run("v4 first then v6", func(t *testing.T) { - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw.snatAddr = snatAddr pkt := slices.Clone(canonicalUDPTest) @@ -923,7 +923,7 @@ func TestFirewall_ApplySnat_PacketUnmodifiedOnError(t *testing.T) { c := &dummyCert{ networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) pkt := slices.Clone(canonicalUDPTest) @@ -948,7 +948,7 @@ func TestFirewall_ApplySnat_PacketUnmodifiedOnError(t *testing.T) { networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw.snatAddr = snatAddr pkt := slices.Clone(canonicalUDPTest) @@ -980,7 +980,7 @@ func TestFirewall_ApplySnat_PacketUnmodifiedOnError(t *testing.T) { networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw.snatAddr = snatAddr pkt := slices.Clone(canonicalUDPTest) @@ -1013,7 +1013,7 @@ func TestFirewall_UnSnat_NonSNATConntrack(t *testing.T) { networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")}, unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw.snatAddr = snatAddr // Create a conntrack entry with snat=nil (a normal non-SNAT connection) @@ -1061,7 +1061,7 @@ func TestFirewall_Drop_FirewallBlocksSNAT(t *testing.T) { issuer: "signer-shasum", } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert) fw.snatAddr = snatAddr // Only allow port 80 inbound require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 80, 80, []string{"any"}, "", "", "any", "", "")) @@ -1121,7 +1121,7 @@ func TestFirewall_Drop_SNATLocalAddrNotRoutable(t *testing.T) { issuer: "signer-shasum", } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert) fw.snatAddr = snatAddr require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "any", "", "")) @@ -1176,7 +1176,7 @@ func TestFirewall_Drop_NoSnatAddrRejectsV6Peer(t *testing.T) { issuer: "signer-shasum", } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert, netip.Addr{}) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "any", "", "")) peerV6Addr := netip.MustParseAddr("fd00::2") @@ -1236,7 +1236,7 @@ func TestFirewall_Drop_IPv4HostNotSNATted(t *testing.T) { issuer: "signer-shasum", } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert, snatAddr) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert) fw.snatAddr = snatAddr require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "any", "", "")) From 5cbccdc0fdf66a036e836464bf2d954e29eb63e4 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Thu, 26 Feb 2026 11:48:53 -0600 Subject: [PATCH 22/31] remove dead comment --- overlay/tun_linux.go | 9 --------- 1 file changed, 9 deletions(-) diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 19a8952b..f8e41710 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -454,15 +454,6 @@ func (t *tun) Activate() error { return fmt.Errorf("failed to run tun device: %s", err) } - //todo hmmmmmm - //pretty sure this is avoidable - //if len(t.unsafeNetworks) != 0 { - // err = os.WriteFile(fmt.Sprintf("/proc/sys/net/ipv4/conf/%s/accept_local", t.Device), []byte("1"), os.FileMode(0o644)) - // if err != nil { - // return err - // } - //} - return nil } From 7655a101089f810d0424e32392f48c73ac471418 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Fri, 27 Feb 2026 16:51:40 -0600 Subject: [PATCH 23/31] Remove thing --- firewall.go | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/firewall.go b/firewall.go index 4d8d7b3b..64ccb262 100644 --- a/firewall.go +++ b/firewall.go @@ -108,15 +108,6 @@ type FirewallConntrack struct { TimerWheel *TimerWheel[firewall.Packet] } -func (ct *FirewallConntrack) dupeConnUnlocked(fp firewall.Packet, c *conn, timeout time.Duration) { - if _, ok := ct.Conns[fp]; !ok { - ct.TimerWheel.Advance(time.Now()) - ct.TimerWheel.Add(fp, timeout) - } - - ct.Conns[fp] = c -} - // FirewallTable is the entry point for a rule, the evaluation order is: // Proto AND port AND (CA SHA or CA name) AND local CIDR AND (group OR groups OR name OR remote CIDR) type FirewallTable struct { @@ -505,7 +496,7 @@ func (f *Firewall) findUsableSNATPort(fp *firewall.Packet, c *conn) error { if !ok { //yay, we can use this port //track the snatted flow with the same expiration as the unsnatted version - conntrack.dupeConnUnlocked(*fp, c, f.packetTimeout(*fp)) + conntrack.Conns[*fp] = c return nil } //increment and retry. There's probably better strategies out there From 037459ef73f3d2ecd539f71e90ca80a0939a7700 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Fri, 27 Feb 2026 17:49:31 -0600 Subject: [PATCH 24/31] Review nits --- firewall.go | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/firewall.go b/firewall.go index 64ccb262..85e9f666 100644 --- a/firewall.go +++ b/firewall.go @@ -23,9 +23,14 @@ import ( "github.com/slackhq/nebula/firewall" ) -var ErrCannotSNAT = errors.New("cannot snat this packet") +var ErrCannotSNAT = errors.New("cannot SNAT this packet") var ErrSNATIdentityMismatch = errors.New("refusing to SNAT for mismatched host") +const ipv4SourcePosition = 12 +const ipv4DestinationPosition = 16 +const sourcePortOffset = 0 +const destinationPortOffset = 2 + type FirewallInterface interface { AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr string, caName string, caSha string) error } @@ -459,7 +464,7 @@ func (f *Firewall) unSnat(data []byte, fp *firewall.Packet) netip.Addr { return netip.Addr{} } oldIP := netip.AddrPortFrom(f.snatAddr, fp.RemotePort) - rewritePacket(data, fp, oldIP, c.snat.Src, 16, 2) + rewritePacket(data, fp, oldIP, c.snat.Src, ipv4DestinationPosition, destinationPortOffset) return c.snat.SrcVpnIp } @@ -496,6 +501,7 @@ func (f *Firewall) findUsableSNATPort(fp *firewall.Packet, c *conn) error { if !ok { //yay, we can use this port //track the snatted flow with the same expiration as the unsnatted version + c.snat.SnatPort = fp.RemotePort conntrack.Conns[*fp] = c return nil } @@ -538,13 +544,12 @@ func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo c.snat = nil return err } - c.snat.SnatPort = fp.RemotePort //may have been updated inside f.findUsableSNATPort } else { return ErrCannotSNAT } newIP := netip.AddrPortFrom(f.snatAddr, c.snat.SnatPort) - rewritePacket(data, fp, c.snat.Src, newIP, 12, 0) + rewritePacket(data, fp, c.snat.Src, newIP, ipv4SourcePosition, sourcePortOffset) return nil } @@ -695,18 +700,9 @@ func (f *Firewall) EmitStats() { } func (f *Firewall) peek(fp firewall.Packet) *conn { - conntrack := f.Conntrack - conntrack.Lock() - - // Purge every time we test - ep, has := conntrack.TimerWheel.Purge() - if has { - f.evict(ep) - } - - c := conntrack.Conns[fp] - - conntrack.Unlock() + f.Conntrack.Lock() + c := f.Conntrack.Conns[fp] + f.Conntrack.Unlock() return c } From d21baede1feff99057d6e9669d4ca300c31776b5 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Fri, 27 Feb 2026 18:09:52 -0600 Subject: [PATCH 25/31] Nits and fix tests --- inside.go | 2 +- interface.go | 1 - overlay/tun_linux.go | 7 ++----- snat_test.go | 6 +++--- 4 files changed, 6 insertions(+), 10 deletions(-) diff --git a/inside.go b/inside.go index 3f7cd19e..a4413aa0 100644 --- a/inside.go +++ b/inside.go @@ -235,7 +235,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp } // check if packet is in outbound fw rules - dropReason := f.firewall.Drop(*fp, nil, false, hostinfo, f.pki.GetCAPool(), nil) + dropReason := f.firewall.Drop(*fp, p, false, hostinfo, f.pki.GetCAPool(), nil) if dropReason != nil { if f.l.Level >= logrus.DebugLevel { f.l.WithField("fwPacket", fp). diff --git a/interface.go b/interface.go index 83d313b5..0acbc147 100644 --- a/interface.go +++ b/interface.go @@ -56,7 +56,6 @@ type Interface struct { inside overlay.Device pki *PKI firewall *Firewall - snatAddr netip.Addr connectionManager *connectionManager handshakeManager *HandshakeManager serveDns bool diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index f8e41710..9e6a7581 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -436,7 +436,7 @@ func (t *tun) Activate() error { return fmt.Errorf("failed to set default route MTU for %s: %w", t.vpnNetworks[i], err) } } - //TODO snat and be snatted + if t.unsafeIPv4Origin.IsValid() { if err = t.setDefaultRoute(t.unsafeIPv4Origin); err != nil { return fmt.Errorf("failed to set default route MTU for %s: %w", t.unsafeIPv4Origin, err) @@ -475,10 +475,7 @@ func (t *tun) setSnatRoute() error { nr := netlink.Route{ LinkIndex: t.deviceIndex, Dst: dr, - //todo do we need these other options? - //MTU: t.DefaultMTU, - //AdvMSS: t.advMSS(Route{}), - Scope: unix.RT_SCOPE_LINK, + Scope: unix.RT_SCOPE_LINK, //Protocol: unix.RTPROT_KERNEL, Table: unix.RT_TABLE_MAIN, Type: unix.RTN_UNICAST, diff --git a/snat_test.go b/snat_test.go index 14d55f4a..708d96a0 100644 --- a/snat_test.go +++ b/snat_test.go @@ -422,7 +422,7 @@ func TestFirewall_FindUsableSNATPort(t *testing.T) { RemotePort: 12345, Protocol: firewall.ProtoUDP, } - cn := &conn{} + cn := &conn{snat: &snatInfo{}} err := fw.findUsableSNATPort(&fp, cn) require.NoError(t, err) // Port should have been assigned @@ -448,7 +448,7 @@ func TestFirewall_FindUsableSNATPort(t *testing.T) { fw.Conntrack.Conns[fp] = &conn{} fw.Conntrack.Unlock() - cn := &conn{} + cn := &conn{snat: &snatInfo{}} err := fw.findUsableSNATPort(&fp, cn) require.NoError(t, err) assert.NotEqual(t, uint16(12345), fp.RemotePort, "should pick a different port") @@ -479,7 +479,7 @@ func TestFirewall_FindUsableSNATPort(t *testing.T) { // Try to find a port starting from 0x8000 fp := baseFP fp.RemotePort = 0x8000 - cn := &conn{} + cn := &conn{snat: &snatInfo{}} err := fw.findUsableSNATPort(&fp, cn) assert.ErrorIs(t, err, ErrCannotSNAT) }) From 09fe406dba8efc24b65fbc175cee1ff7f8f0daf5 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Wed, 4 Mar 2026 12:08:00 -0600 Subject: [PATCH 26/31] log if V1 and V2 certs have mismatched UnsafeNetworks --- firewall_test.go | 2 +- pki.go | 37 +++++++++++++++++++++++++++++++++---- 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/firewall_test.go b/firewall_test.go index 4df6eadd..dc863319 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -1042,7 +1042,7 @@ func TestNewFirewallFromConfig(t *testing.T) { l := test.NewLogger() // Test a bad rule definition c := &dummyCert{} - cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil) + cs, err := newCertState(l, cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil) require.NoError(t, err) conf := config.NewC(l) diff --git a/pki.go b/pki.go index 19869d58..5744e5af 100644 --- a/pki.go +++ b/pki.go @@ -91,7 +91,7 @@ func (p *PKI) reload(c *config.C, initial bool) error { } func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError { - newState, err := newCertStateFromConfig(c) + newState, err := newCertStateFromConfig(c, p.l) if err != nil { return util.NewContextualError("Could not load client cert", nil, err) } @@ -260,7 +260,7 @@ func (cs *CertState) MarshalJSON() ([]byte, error) { return json.Marshal(msg) } -func newCertStateFromConfig(c *config.C) (*CertState, error) { +func newCertStateFromConfig(c *config.C, l *logrus.Logger) (*CertState, error) { var err error privPathOrPEM := c.GetString("pki.key", "") @@ -344,10 +344,33 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) { return nil, fmt.Errorf("unknown pki.initiating_version: %v", rawInitiatingVersion) } - return newCertState(initiatingVersion, v1, v2, isPkcs11, curve, rawKey) + return newCertState(l, initiatingVersion, v1, v2, isPkcs11, curve, rawKey) } -func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) { +func compareUnsafeNetworksAcrossCertVersions(v1, v2 cert.Certificate) error { + if v1 == nil || v2 == nil { + return nil //can't be a problem if we don't have one of the kinds of cert + } + + v4UnsafeNets := 0 + for _, n := range v2.UnsafeNetworks() { + if n.Addr().Is6() { + continue // V1 certs can't have IPv6 unsafe networks + } else { + v4UnsafeNets++ + } + if !slices.Contains(v1.UnsafeNetworks(), n) { + return errors.New("UnsafeNetworks mismatch") + } + } + if len(v1.UnsafeNetworks()) != v4UnsafeNets { + return errors.New("UnsafeNetworks mismatch") + } + + return nil +} + +func newCertState(l *logrus.Logger, dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) { cs := CertState{ privateKey: privateKey, pkcs11Backed: pkcs11backed, @@ -370,6 +393,12 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p } cs.initiatingVersion = dv + + warn := compareUnsafeNetworksAcrossCertVersions(v1, v2) + if warn != nil { + l.WithFields(m{"UnsafeNetworksV1": v1.UnsafeNetworks(), "UnsafeNetworksV2": v2.UnsafeNetworks()}). + Warning("the IPv4 UnsafeNetworks assigned in the V1 certificate do not match the ones in V2") + } } if v1 != nil { From 36bbc515d2df2e652cfd7270cb3e70aee1d30e8d Mon Sep 17 00:00:00 2001 From: JackDoan Date: Wed, 4 Mar 2026 12:33:16 -0600 Subject: [PATCH 27/31] log if UnsafeNetworks assignment changes across reload --- pki.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/pki.go b/pki.go index 5744e5af..3deb0fe7 100644 --- a/pki.go +++ b/pki.go @@ -102,7 +102,7 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError { if currentState.v1Cert == nil { //adding certs is fine, actually. Networks-in-common confirmed in newCertState(). } else { - // did IP in cert change? if so, don't set + // did IP in cert change? if so, don't set. If we ever allow this, need to set p.firewallReloadNeeded if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) { return util.NewContextualError( "Networks in new cert was different from old", @@ -158,6 +158,14 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError { } } + newUN := newState.GetDefaultCertificate().UnsafeNetworks() + oldUN := currentState.GetDefaultCertificate().UnsafeNetworks() + if !slices.Equal(newUN, oldUN) { + //todo I don't love this, because other clients will see the new assignments and act on them, but we will not be able to. + //I think we need to wire this into the firewall reload. + p.l.WithFields(m{"previous": oldUN, "new": newUN}).Warning("UnsafeNetworks assignments differ. A restart is required in order for this to take effect.") + } + // Cipher cant be hot swapped so just leave it at what it was before newState.cipher = currentState.cipher From 1580175b2e6d6ce19b35282b52492ddea3ae5b22 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Wed, 4 Mar 2026 12:36:23 -0600 Subject: [PATCH 28/31] remove silly panic --- firewall.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/firewall.go b/firewall.go index 85e9f666..6146a058 100644 --- a/firewall.go +++ b/firewall.go @@ -223,12 +223,12 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) { certificate := cs.getCertificate(cert.Version2) - if certificate == nil { + if certificate == nil { //todo if config.initiating_version is set to 1, and unsafe_networks differ, things will suck certificate = cs.getCertificate(cert.Version1) } if certificate == nil { - panic("No certificate available to reconfigure the firewall") + return nil, errors.New("no certificate available to reconfigure the firewall") } fw := NewFirewall( From 2e50518066ce379f7b1e55c9a2814523af10c92b Mon Sep 17 00:00:00 2001 From: JackDoan Date: Wed, 4 Mar 2026 13:08:15 -0600 Subject: [PATCH 29/31] new error --- firewall.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/firewall.go b/firewall.go index 6146a058..dea7a9c8 100644 --- a/firewall.go +++ b/firewall.go @@ -25,6 +25,7 @@ import ( var ErrCannotSNAT = errors.New("cannot SNAT this packet") var ErrSNATIdentityMismatch = errors.New("refusing to SNAT for mismatched host") +var ErrSNATAddressCollision = errors.New("refusing to accept an incoming packet with my SNAT address") const ipv4SourcePosition = 12 const ipv4DestinationPosition = 16 @@ -522,7 +523,7 @@ func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo return ErrCannotSNAT } if f.snatAddr == fp.LocalAddr { //a packet that came from UDP (incoming) should never ever have our snat address on it - return ErrSNATIdentityMismatch + return ErrSNATAddressCollision } if c.snat.Valid() { //old flow: make sure it came from the right place From a2c2235b9bdc5b047a9935e36a1509a97e02e2a3 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Wed, 4 Mar 2026 13:11:23 -0600 Subject: [PATCH 30/31] rename --- firewall.go | 6 +++--- hostmap.go | 4 ++-- snat_test.go | 10 +++++----- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/firewall.go b/firewall.go index dea7a9c8..27d48171 100644 --- a/firewall.go +++ b/firewall.go @@ -566,7 +566,7 @@ func (f *Firewall) identifyRemoteNetworkType(h *HostInfo, fp firewall.Packet) Ne //RemoteAddr not in our networks table if f.snatAddr.IsValid() && fp.IsIPv4() && h.HasOnlyV6Addresses() { - return NetworkTypeUncheckedSNATPeer + return NetworkTypeUnverifiedSNATPeer } else { return NetworkTypeInvalidPeer } @@ -583,7 +583,7 @@ func (f *Firewall) allowRemoteNetworkType(nwType NetworkType, fp firewall.Packet return ErrPeerRejected // reject for now, one day this may have different FW rules case NetworkTypeUnsafe: return nil // nothing special, one day this may have different FW rules - case NetworkTypeUncheckedSNATPeer: + case NetworkTypeUnverifiedSNATPeer: if f.unsafeIPv4Origin.IsValid() && fp.LocalAddr == f.unsafeIPv4Origin { return nil //the client case } @@ -668,7 +668,7 @@ func (f *Firewall) Drop(fp firewall.Packet, pkt []byte, incoming bool, h *HostIn // We always want to conntrack since it is a faster operation c = f.addConn(fp, incoming) - if incoming && remoteNetworkType == NetworkTypeUncheckedSNATPeer { + if incoming && remoteNetworkType == NetworkTypeUnverifiedSNATPeer { return f.applySnat(pkt, &fp, c, h) } else { //outgoing snat is handled before this function is called diff --git a/hostmap.go b/hostmap.go index ff5ee456..f50dd875 100644 --- a/hostmap.go +++ b/hostmap.go @@ -224,8 +224,8 @@ const ( NetworkTypeVPNPeer // NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks() NetworkTypeUnsafe - // NetworkTypeUncheckedSNATPeer is used to indicate traffic we're willing to route, but never deliver to a NetworkTypeVPN - NetworkTypeUncheckedSNATPeer + // NetworkTypeUnverifiedSNATPeer is used to indicate traffic we're willing to route, but never deliver to a NetworkTypeVPN + NetworkTypeUnverifiedSNATPeer NetworkTypeInvalidPeer ) diff --git a/snat_test.go b/snat_test.go index 708d96a0..1bba83ce 100644 --- a/snat_test.go +++ b/snat_test.go @@ -335,7 +335,7 @@ func TestFirewall_IdentifyNetworkType_SNATPeer(t *testing.T) { RemoteAddr: netip.MustParseAddr("10.0.0.1"), LocalAddr: netip.MustParseAddr("192.168.1.1"), } - assert.Equal(t, NetworkTypeUncheckedSNATPeer, fw.identifyRemoteNetworkType(h, fp)) + assert.Equal(t, NetworkTypeUnverifiedSNATPeer, fw.identifyRemoteNetworkType(h, fp)) }) t.Run("v4 packet from v4 host is not snat peer", func(t *testing.T) { @@ -373,12 +373,12 @@ func TestFirewall_AllowNetworkType_SNAT(t *testing.T) { //todo fix! //t.Run("snat peer allowed with snat addr", func(t *testing.T) { // fw := &Firewall{snatAddr: netip.MustParseAddr("169.254.55.96")} - // assert.NoError(t, fw.allowRemoteNetworkType(NetworkTypeUncheckedSNATPeer, fp)) + // assert.NoError(t, fw.allowRemoteNetworkType(NetworkTypeUnverifiedSNATPeer, fp)) //}) // //t.Run("snat peer rejected without snat addr", func(t *testing.T) { // fw := &Firewall{} - // assert.ErrorIs(t, fw.allowRemoteNetworkType(NetworkTypeUncheckedSNATPeer, fp), ErrInvalidRemoteIP) + // assert.ErrorIs(t, fw.allowRemoteNetworkType(NetworkTypeUnverifiedSNATPeer, fp), ErrInvalidRemoteIP) //}) t.Run("vpn always allowed", func(t *testing.T) { @@ -1291,7 +1291,7 @@ func TestFirewall_Drop_IPv4HostNotSNATted(t *testing.T) { } nwType := fw.identifyRemoteNetworkType(h, fp) assert.Equal(t, NetworkTypeVPN, nwType, "v4 peer using its own VPN addr should be NetworkTypeVPN") - assert.NotEqual(t, NetworkTypeUncheckedSNATPeer, nwType, "must NOT be classified as SNAT peer") + assert.NotEqual(t, NetworkTypeUnverifiedSNATPeer, nwType, "must NOT be classified as SNAT peer") }) t.Run("identifyRemoteNetworkType v4 peer with mismatched source", func(t *testing.T) { @@ -1305,6 +1305,6 @@ func TestFirewall_Drop_IPv4HostNotSNATted(t *testing.T) { } nwType := fw.identifyRemoteNetworkType(h, fp) assert.Equal(t, NetworkTypeInvalidPeer, nwType, "v4 peer with mismatched source should be InvalidPeer") - assert.NotEqual(t, NetworkTypeUncheckedSNATPeer, nwType, "must NOT be classified as SNAT peer") + assert.NotEqual(t, NetworkTypeUnverifiedSNATPeer, nwType, "must NOT be classified as SNAT peer") }) } From 3e3bd9ceadf0eff5adaf19b317824964ed7b0f4d Mon Sep 17 00:00:00 2001 From: JackDoan Date: Wed, 4 Mar 2026 13:38:58 -0600 Subject: [PATCH 31/31] add some context for the next guy --- overlay/tun_linux.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 9e6a7581..1b70e8b3 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -560,6 +560,10 @@ func (t *tun) addRoutes(logErrors bool) error { } if t.snatAddr.IsValid() { + //at least for Linux, we need to set a return route for the SNATted traffic in order to satisfy the reverse-path filter, + //and to help the kernel deliver our reply traffic to the tun device. + //however, it is important that we do not actually /assign/ the SNAT address, + //since link-local addresses will not be routed between interfaces without significant trickery. return t.setSnatRoute() } return nil