From 935d0ba76fbd5448928916a8d938bb1fc50eb36b Mon Sep 17 00:00:00 2001 From: JackDoan Date: Wed, 14 Jan 2026 12:36:55 -0600 Subject: [PATCH 01/16] srcsnort prototype --- cert/cert_v2.go | 2 +- firewall.go | 18 ++++++---- firewall/packet.go | 4 +++ hostmap.go | 10 ++++++ inside.go | 17 ++++++++-- interface.go | 67 ++++++++++++++++++++++++++---------- outside.go | 85 ++++++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 176 insertions(+), 27 deletions(-) diff --git a/cert/cert_v2.go b/cert/cert_v2.go index 4648c496..87d1ec11 100644 --- a/cert/cert_v2.go +++ b/cert/cert_v2.go @@ -441,7 +441,7 @@ func (c *certificateV2) validate() error { } } else if network.Addr().Is4() { if !hasV4Networks { - return NewErrInvalidCertificateProperties("IPv4 unsafe networks require an IPv4 address assignment: %s", network) + //return NewErrInvalidCertificateProperties("IPv4 unsafe networks require an IPv4 address assignment: %s", network) } } } diff --git a/firewall.go b/firewall.go index 45dc0691..1ce66221 100644 --- a/firewall.go +++ b/firewall.go @@ -408,15 +408,20 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * if f.inConns(fp, h, caPool, localCache) { return nil } + if fp.IsIPv4() && h.HasOnlyV6Addresses() { + //todo!!! special case: fp.RemoteAddr is v4, and cert is v6 only. We want to accept and do NAT internally + } // Make sure remote address matches nebula certificate, and determine how to treat it + if h.networks == nil { // Simple case: Certificate has one address and no unsafe networks - if h.vpnAddrs[0] != fp.RemoteAddr { + if h.vpnAddrs[0] != fp.RemoteAddr && fp.RemoteAddr != srcsnortaddr /*todo get this from interface */ { f.metrics(incoming).droppedRemoteAddr.Inc(1) - return ErrInvalidRemoteIP + return ErrInvalidRemoteIP //todo! } } else { + //todo check for srcsnortaddr here too nwType, ok := h.networks.Lookup(fp.RemoteAddr) if !ok { f.metrics(incoming).droppedRemoteAddr.Inc(1) @@ -437,10 +442,11 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * } // Make sure we are supposed to be handling this local ip address - if !f.routableNetworks.Contains(fp.LocalAddr) { - f.metrics(incoming).droppedLocalAddr.Inc(1) - return ErrInvalidLocalIP - } + //todo probably bad! + //if !f.routableNetworks.Contains(fp.LocalAddr) { + // f.metrics(incoming).droppedLocalAddr.Inc(1) + // return ErrInvalidLocalIP + //} table := f.OutRules if incoming { diff --git a/firewall/packet.go b/firewall/packet.go index 40c7fc5d..ade2ad55 100644 --- a/firewall/packet.go +++ b/firewall/packet.go @@ -28,6 +28,10 @@ type Packet struct { Fragment bool } +func (fp *Packet) IsIPv4() bool { + return fp.LocalAddr.Is4() && fp.RemoteAddr.Is4() +} + func (fp *Packet) Copy() *Packet { return &Packet{ LocalAddr: fp.LocalAddr, diff --git a/hostmap.go b/hostmap.go index 7e2939e0..a451bd98 100644 --- a/hostmap.go +++ b/hostmap.go @@ -224,6 +224,7 @@ const ( NetworkTypeVPNPeer // NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks() NetworkTypeUnsafe + //todo consider NetworkTypeLinkLocal or NetworkTypeSNAT ) type HostInfo struct { @@ -277,6 +278,15 @@ type HostInfo struct { lastUsed time.Time } +func (i *HostInfo) HasOnlyV6Addresses() bool { + for _, vpnIp := range i.vpnAddrs { + if !vpnIp.Is6() { + return false + } + } + return true +} + type ViaSender struct { UdpAddr netip.AddrPort relayHI *HostInfo // relayHI is the host info object of the relay diff --git a/inside.go b/inside.go index 0d53f952..a7731a40 100644 --- a/inside.go +++ b/inside.go @@ -48,9 +48,20 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet return } - hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) { - hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) - }) + var hostinfo *HostInfo + var ready bool + + if fwPacket.RemoteAddr == f.snatMaps.snatIP { + //todo unsnat happens here + hostinfo = f.unSnat(packet, fwPacket) //todo bail if we can't unsnat? + ready = hostinfo != nil //todo feels hacky and bad + } + + if hostinfo == nil { //if we didn't unsnat + hostinfo, ready = f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) { + hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) + }) + } if hostinfo == nil { f.rejectInside(packet, out, q) diff --git a/interface.go b/interface.go index 9f83d183..10a82fe7 100644 --- a/interface.go +++ b/interface.go @@ -50,6 +50,26 @@ type InterfaceConfig struct { l *logrus.Logger } +type SnatMapping struct { + Src netip.AddrPort + SrcHostInfo *HostInfo +} + +type SnatMap struct { + m map[uint16]SnatMapping +} + +func (s *SnatMap) addMapping(src netip.AddrPort) { + +} + +type SnatMaps struct { + TCP SnatMap + UDP SnatMap + ICMP SnatMap //todo index? + snatIP netip.Addr +} + type Interface struct { hostMap *HostMap outside udp.Conn @@ -85,8 +105,9 @@ type Interface struct { conntrackCacheTimeout time.Duration - writers []udp.Conn - readers []io.ReadWriteCloser + writers []udp.Conn + readers []io.ReadWriteCloser + snatMaps *SnatMaps //todo this needs some kind of atomic semantics for cross-routine access metricHandshakes metrics.Histogram messageMetrics *MessageMetrics @@ -163,21 +184,33 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { cs := c.pki.getCertState() ifce := &Interface{ - pki: c.pki, - hostMap: c.HostMap, - outside: c.Outside, - inside: c.Inside, - firewall: c.Firewall, - serveDns: c.ServeDns, - handshakeManager: c.HandshakeManager, - createTime: time.Now(), - lightHouse: c.lightHouse, - dropLocalBroadcast: c.DropLocalBroadcast, - dropMulticast: c.DropMulticast, - routines: c.routines, - version: c.version, - writers: make([]udp.Conn, c.routines), - readers: make([]io.ReadWriteCloser, c.routines), + pki: c.pki, + hostMap: c.HostMap, + outside: c.Outside, + inside: c.Inside, + firewall: c.Firewall, + serveDns: c.ServeDns, + handshakeManager: c.HandshakeManager, + createTime: time.Now(), + lightHouse: c.lightHouse, + dropLocalBroadcast: c.DropLocalBroadcast, + dropMulticast: c.DropMulticast, + routines: c.routines, + version: c.version, + writers: make([]udp.Conn, c.routines), + readers: make([]io.ReadWriteCloser, c.routines), + snatMaps: &SnatMaps{ + TCP: SnatMap{ + m: map[uint16]SnatMapping{}, + }, + UDP: SnatMap{ + m: map[uint16]SnatMapping{}, + }, + ICMP: SnatMap{ + m: map[uint16]SnatMapping{}, + }, + snatIP: srcsnortaddr, //todo this should be source of truthed here + }, myVpnNetworks: cs.myVpnNetworks, myVpnNetworksTable: cs.myVpnNetworksTable, myVpnAddrs: cs.myVpnAddrs, diff --git a/outside.go b/outside.go index b1a28e57..965931b0 100644 --- a/outside.go +++ b/outside.go @@ -400,6 +400,85 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error { return ErrIPv6CouldNotFindPayload } +var srcsnortaddr = netip.MustParseAddr("169.254.55.96") + +func CalculateIPv4Checksum(header []byte) uint16 { + //todo this should be elsewhere + headerLen := int(header[0]&0x0F) * 4 + + if len(header) < headerLen { + return 0 + } + + var sum uint32 + for i := 0; i < headerLen; i += 2 { + word := uint32(binary.BigEndian.Uint16(header[i : i+2])) + sum += word + } + + for sum > 0xFFFF { + sum = (sum & 0xFFFF) + (sum >> 16) + } + + return uint16(^sum) +} + +func recalcIPv4Checksum(data []byte) { + data[10] = 0 + data[11] = 0 + checksum := CalculateIPv4Checksum(data) + binary.BigEndian.PutUint16(data[10:12], checksum) +} + +func (f *Interface) unSnat(data []byte, fp *firewall.Packet) *HostInfo { + var mapping SnatMapping + var ok bool + switch fp.Protocol { + case firewall.ProtoICMP: + //todo hack + mapping, ok = f.snatMaps.ICMP.m[0] + default: + f.l.WithField("fwPacket", fp).Warn("Unsupported unSNAT protocol") + return nil + } + if !ok { + f.l.WithField("fwPacket", fp).Warn("got a snat packet we don't know how to unsnat") + return nil + } + + copy(data[16:], mapping.Src.Addr().AsSlice()) + + recalcIPv4Checksum(data) + return mapping.SrcHostInfo +} + +func (f *Interface) applySnat(data []byte, fp *firewall.Packet, hostinfo *HostInfo) { + if !f.snatMaps.snatIP.Is4() { + return //bad! + } + + //todo math should exist to take existing checksum, old ip, new ip, and set new checksum, right? + + //todo set srcport + //todo record mapping somehow??? sadly the somehow has to be safe/sane across all routines + switch fp.Protocol { + case firewall.ProtoICMP, firewall.ProtoICMPv6: + f.snatMaps.ICMP.m[0] = SnatMapping{ + Src: netip.AddrPortFrom(fp.RemoteAddr, fp.RemotePort), + SrcHostInfo: hostinfo, + } + case firewall.ProtoTCP: + //todo + case firewall.ProtoUDP: + //also todo + } + + fp.RemoteAddr = f.snatMaps.snatIP + copy(data[12:], f.snatMaps.snatIP.AsSlice()) + + recalcIPv4Checksum(data) +} + func parseV4(data []byte, incoming bool, fp *firewall.Packet) error { // Do we at least have an ipv4 header worth of data? if len(data) < ipv4.HeaderLen { @@ -494,6 +573,12 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out return false } + //todo apply srcsnort here? + //todo rp_filter will need to be set or defeated somehow + if fwPacket.IsIPv4() && hostinfo.HasOnlyV6Addresses() { + f.applySnat(out, fwPacket, hostinfo) + } + dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache) if dropReason != nil { // NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore From 68831a1d551cb69b52c3ddb23d88480fd5a98c8a Mon Sep 17 00:00:00 2001 From: JackDoan Date: Wed, 14 Jan 2026 13:09:12 -0600 Subject: [PATCH 02/16] hack the firewall a little less --- firewall.go | 25 +++++++++++++++---------- outside.go | 6 +++++- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/firewall.go b/firewall.go index 1ce66221..f2e3b2f2 100644 --- a/firewall.go +++ b/firewall.go @@ -400,6 +400,7 @@ var ErrPeerRejected = errors.New("remote address is not within a network that we var ErrInvalidRemoteIP = errors.New("remote address is not in remote certificate networks") var ErrInvalidLocalIP = errors.New("local address is not in list of handled local addresses") var ErrNoMatchingRule = errors.New("no matching rule in firewall table") +var ErrSnatRequired = errors.New("snat required to pass traffic") // Drop returns an error if the packet should be dropped, explaining why. It // returns nil if the packet should not be dropped. @@ -408,12 +409,12 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * if f.inConns(fp, h, caPool, localCache) { return nil } - if fp.IsIPv4() && h.HasOnlyV6Addresses() { - //todo!!! special case: fp.RemoteAddr is v4, and cert is v6 only. We want to accept and do NAT internally - } // Make sure remote address matches nebula certificate, and determine how to treat it + var err error + specialSnatMode := false + if h.networks == nil { // Simple case: Certificate has one address and no unsafe networks if h.vpnAddrs[0] != fp.RemoteAddr && fp.RemoteAddr != srcsnortaddr /*todo get this from interface */ { @@ -421,7 +422,7 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * return ErrInvalidRemoteIP //todo! } } else { - //todo check for srcsnortaddr here too + //todo check for srcsnortaddr here too? nwType, ok := h.networks.Lookup(fp.RemoteAddr) if !ok { f.metrics(incoming).droppedRemoteAddr.Inc(1) @@ -434,6 +435,10 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * f.metrics(incoming).droppedRemoteAddr.Inc(1) return ErrPeerRejected // reject for now, one day this may have different FW rules case NetworkTypeUnsafe: + if fp.IsIPv4() && h.HasOnlyV6Addresses() { + //err = ErrSnatRequired + specialSnatMode = true + } break // nothing special, one day this may have different FW rules default: f.metrics(incoming).droppedRemoteAddr.Inc(1) @@ -442,11 +447,11 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * } // Make sure we are supposed to be handling this local ip address - //todo probably bad! - //if !f.routableNetworks.Contains(fp.LocalAddr) { - // f.metrics(incoming).droppedLocalAddr.Inc(1) - // return ErrInvalidLocalIP - //} + //todo I'm not sure I trust this heuristic + if !specialSnatMode && !f.routableNetworks.Contains(fp.LocalAddr) { + f.metrics(incoming).droppedLocalAddr.Inc(1) + return ErrInvalidLocalIP + } table := f.OutRules if incoming { @@ -462,7 +467,7 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * // We always want to conntrack since it is a faster operation f.addConn(fp, incoming) - return nil + return err } func (f *Firewall) metrics(incoming bool) firewallMetrics { diff --git a/outside.go b/outside.go index 965931b0..b994a032 100644 --- a/outside.go +++ b/outside.go @@ -576,7 +576,11 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out //todo apply srcsnort here? //todo rp_filter will need to be set or defeated somehow if fwPacket.IsIPv4() && hostinfo.HasOnlyV6Addresses() { - f.applySnat(out, fwPacket, hostinfo) + if len(f.pki.getCertState().GetDefaultCertificate().UnsafeNetworks()) != 0 { + //todo do not snat if you are not a router for the destination -- for now, just if you're not a router + //f.myVpnNetworksTable.Contains(fwPacket.RemoteAddr) + f.applySnat(out, fwPacket, hostinfo) + } } dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache) From 213dd2573308f368edf5b1be639906f39f3c3efd Mon Sep 17 00:00:00 2001 From: JackDoan Date: Wed, 14 Jan 2026 15:01:35 -0600 Subject: [PATCH 03/16] don't abuse hostinfo --- inside.go | 10 +++++++--- interface.go | 4 ++-- outside.go | 12 ++++++------ 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/inside.go b/inside.go index a7731a40..3735e470 100644 --- a/inside.go +++ b/inside.go @@ -49,15 +49,19 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } var hostinfo *HostInfo + var destVpnAddr netip.Addr var ready bool if fwPacket.RemoteAddr == f.snatMaps.snatIP { //todo unsnat happens here - hostinfo = f.unSnat(packet, fwPacket) //todo bail if we can't unsnat? - ready = hostinfo != nil //todo feels hacky and bad + destVpnAddr = f.unSnat(packet, fwPacket) //todo bail if we can't unsnat? } - if hostinfo == nil { //if we didn't unsnat + if destVpnAddr.IsValid() { + hostinfo, ready = f.getOrHandshakeNoRouting(destVpnAddr, func(hh *HandshakeHostInfo) { + hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) + }) + } else { //if we didn't need to unsnat hostinfo, ready = f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) { hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) }) diff --git a/interface.go b/interface.go index 10a82fe7..aa9ba50c 100644 --- a/interface.go +++ b/interface.go @@ -51,8 +51,8 @@ type InterfaceConfig struct { } type SnatMapping struct { - Src netip.AddrPort - SrcHostInfo *HostInfo + Src netip.AddrPort + SrcVpnIp netip.Addr } type SnatMap struct { diff --git a/outside.go b/outside.go index b994a032..e5137e48 100644 --- a/outside.go +++ b/outside.go @@ -430,7 +430,7 @@ func recalcIPv4Checksum(data []byte) { binary.BigEndian.PutUint16(data[10:12], checksum) } -func (f *Interface) unSnat(data []byte, fp *firewall.Packet) *HostInfo { +func (f *Interface) unSnat(data []byte, fp *firewall.Packet) netip.Addr { var mapping SnatMapping var ok bool switch fp.Protocol { @@ -439,17 +439,17 @@ func (f *Interface) unSnat(data []byte, fp *firewall.Packet) *HostInfo { mapping, ok = f.snatMaps.ICMP.m[0] default: f.l.WithField("fwPacket", fp).Warn("Unsupported unSNAT protocol") - return nil + return netip.Addr{} } if !ok { f.l.WithField("fwPacket", fp).Warn("got a snat packet we don't know how to unsnat") - return nil + return netip.Addr{} } copy(data[16:], mapping.Src.Addr().AsSlice()) recalcIPv4Checksum(data) - return mapping.SrcHostInfo + return mapping.SrcVpnIp } func (f *Interface) applySnat(data []byte, fp *firewall.Packet, hostinfo *HostInfo) { @@ -464,8 +464,8 @@ func (f *Interface) applySnat(data []byte, fp *firewall.Packet, hostinfo *HostIn switch fp.Protocol { case firewall.ProtoICMP, firewall.ProtoICMPv6: f.snatMaps.ICMP.m[0] = SnatMapping{ - Src: netip.AddrPortFrom(fp.RemoteAddr, fp.RemotePort), - SrcHostInfo: hostinfo, + Src: netip.AddrPortFrom(fp.RemoteAddr, fp.RemotePort), + SrcVpnIp: hostinfo.vpnAddrs[0], //todo I hope this is ipv6 } case firewall.ProtoTCP: //todo From 293105fb8066398217677b534d14ce4295490d2f Mon Sep 17 00:00:00 2001 From: JackDoan Date: Wed, 14 Jan 2026 19:23:01 -0600 Subject: [PATCH 04/16] udp "works" --- firewall.go | 198 ++++++++++++++++++++++++++++++++++++++++------- firewall_test.go | 60 +++++++------- inside.go | 23 +++--- interface.go | 67 ++++------------ outside.go | 91 +--------------------- snat.go | 72 +++++++++++++++++ 6 files changed, 299 insertions(+), 212 deletions(-) create mode 100644 snat.go diff --git a/firewall.go b/firewall.go index f2e3b2f2..3ec512d6 100644 --- a/firewall.go +++ b/firewall.go @@ -26,6 +26,15 @@ type FirewallInterface interface { AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr string, caName string, caSha string) error } +type snatInfo struct { + Src netip.AddrPort + SrcVpnIp netip.Addr +} + +func (s *snatInfo) Valid() bool { + return s.Src.IsValid() +} + type conn struct { Expires time.Time // Time when this conntrack entry will expire @@ -34,6 +43,9 @@ type conn struct { // fields pack for free after the uint32 above incoming bool rulesVersion uint16 + + //for SNAT support + snat snatInfo } // TODO: need conntrack max tracked connections handling @@ -66,6 +78,7 @@ type Firewall struct { defaultLocalCIDRAny bool incomingMetrics firewallMetrics outgoingMetrics firewallMetrics + snatAddr netip.Addr l *logrus.Logger } @@ -81,6 +94,11 @@ type FirewallConntrack struct { Conns map[firewall.Packet]*conn TimerWheel *TimerWheel[firewall.Packet] + // SNATFlows maps protocol->source_port->original packet info for unsnatting. + // the srcport to use for outgoing snat flows is stored in Conns. + // When a flow is expired from Conns, it needs to be removed from SNATFlows as well. + // todo if we put "both" keys into Conns, we can potentially avoid this problem + SNATFlows map[int]map[uint16]snatInfo } // FirewallTable is the entry point for a rule, the evaluation order is: @@ -163,6 +181,11 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D hasUnsafeNetworks = true } + snatAddr := netip.Addr{} + if hasUnsafeNetworks { + snatAddr = netip.MustParseAddr("169.254.55.96") + } + return &Firewall{ Conntrack: &FirewallConntrack{ Conns: make(map[firewall.Packet]*conn), @@ -176,6 +199,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D routableNetworks: routableNetworks, assignedNetworks: assignedNetworks, hasUnsafeNetworks: hasUnsafeNetworks, + snatAddr: snatAddr, l: l, incomingMetrics: firewallMetrics{ @@ -402,24 +426,88 @@ var ErrInvalidLocalIP = errors.New("local address is not in list of handled loca var ErrNoMatchingRule = errors.New("no matching rule in firewall table") var ErrSnatRequired = errors.New("snat required to pass traffic") -// Drop returns an error if the packet should be dropped, explaining why. It -// returns nil if the packet should not be dropped. -func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error { - // Check if we spoke to this tuple, if we did then allow this packet - if f.inConns(fp, h, caPool, localCache) { - return nil +func (f *Firewall) unSnat(data []byte, fp *firewall.Packet, c *conn, caPool *cert.CAPool) netip.Addr { + if c == nil { + //unfortunately this needs to lock. Surely there's a better way, but I need to make this flow at all first. + c = f.peek(*fp, nil, caPool, nil) + } + if c == nil { + return netip.Addr{} + } + if !c.snat.Valid() { + return netip.Addr{} } - // Make sure remote address matches nebula certificate, and determine how to treat it + copy(data[16:], c.snat.Src.Addr().AsSlice()) + + recalcIPv4Checksum(data) + switch fp.Protocol { + case firewall.ProtoUDP: + recalcUDPv4Checksum(data, f.snatAddr, c.snat.Src.Addr(), fp.RemotePort, c.snat.Src.Port()) + } + return c.snat.SrcVpnIp +} + +func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo *HostInfo) { + //todo math should exist to take existing checksum, old ip, new ip, and set new checksum, right? + + //todo set srcport + //todo record mapping somehow??? sadly the somehow has to be safe/sane across all routines + //switch fp.Protocol { + //case firewall.ProtoICMP: + // + // + //case firewall.ProtoTCP: + // //todo + //case firewall.ProtoUDP: + // //also todo + //} + // + c.snat.Src = netip.AddrPortFrom(fp.RemoteAddr, fp.RemotePort) + c.snat.SrcVpnIp = hostinfo.vpnAddrs[0] //todo I hope this is ipv6 + fp.RemoteAddr = f.snatAddr + + copy(data[12:], f.snatAddr.AsSlice()) + + recalcIPv4Checksum(data) + switch fp.Protocol { + case firewall.ProtoUDP: + //todo change the port + recalcUDPv4Checksum(data, c.snat.Src.Addr(), f.snatAddr, c.snat.Src.Port(), fp.RemotePort) + case firewall.ProtoTCP: + //todo recalc checksum + } +} + +// Drop returns an error if the packet should be dropped, explaining why. It +// returns nil if the packet should not be dropped. +func (f *Firewall) Drop(fp firewall.Packet, pkt []byte, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error { + // Check if we spoke to this tuple, if we did then allow this packet + c := f.inConns(fp, h, caPool, localCache) + //can't return yet, need to snat maybe var err error specialSnatMode := false + table := f.OutRules + if incoming { + table = f.InRules + } + + if c != nil { + specialSnatMode = fp.IsIPv4() && h.HasOnlyV6Addresses() //todo? + goto snat + } + + // Make sure remote address matches nebula certificate, and determine how to treat it if h.networks == nil { // Simple case: Certificate has one address and no unsafe networks - if h.vpnAddrs[0] != fp.RemoteAddr && fp.RemoteAddr != srcsnortaddr /*todo get this from interface */ { - f.metrics(incoming).droppedRemoteAddr.Inc(1) - return ErrInvalidRemoteIP //todo! + if h.vpnAddrs[0] != fp.RemoteAddr { + specialSnatMode = fp.IsIPv4() && h.HasOnlyV6Addresses() + if !specialSnatMode { + f.metrics(incoming).droppedRemoteAddr.Inc(1) + return ErrInvalidRemoteIP //todo! + } } } else { //todo check for srcsnortaddr here too? @@ -435,10 +523,7 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * f.metrics(incoming).droppedRemoteAddr.Inc(1) return ErrPeerRejected // reject for now, one day this may have different FW rules case NetworkTypeUnsafe: - if fp.IsIPv4() && h.HasOnlyV6Addresses() { - //err = ErrSnatRequired - specialSnatMode = true - } + specialSnatMode = fp.IsIPv4() && h.HasOnlyV6Addresses() break // nothing special, one day this may have different FW rules default: f.metrics(incoming).droppedRemoteAddr.Inc(1) @@ -453,11 +538,6 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * return ErrInvalidLocalIP } - table := f.OutRules - if incoming { - table = f.InRules - } - // We now know which firewall table to check against if !table.match(fp, incoming, h.ConnectionState.peerCert, caPool) { f.metrics(incoming).droppedNoRule.Inc(1) @@ -465,7 +545,22 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * } // We always want to conntrack since it is a faster operation - f.addConn(fp, incoming) + c = f.addConn(fp, incoming) + +snat: + if incoming { + //todo rp_filter will need to be set or defeated somehow + if specialSnatMode { + if f.hasUnsafeNetworks { + //todo do not snat if you are not a router for the destination -- for now, just if you're not a router + //f.myVpnNetworksTable.Contains(fwPacket.RemoteAddr) + f.applySnat(pkt, &fp, c, h) + f.dupeConn(fp, c) + } + } + } else { //outgoing + + } return err } @@ -494,12 +589,35 @@ func (f *Firewall) EmitStats() { metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV())) } -func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool { - if localCache != nil { - if _, ok := localCache[fp]; ok { - return true - } +func (f *Firewall) peek(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) *conn { + //big todo, this cache needs to know snat info as well + //if localCache != nil { + // if out, ok := localCache[fp]; ok { + // return out + // } + //} + conntrack := f.Conntrack + conntrack.Lock() + + // Purge every time we test + ep, has := conntrack.TimerWheel.Purge() + if has { + f.evict(ep) } + + c := conntrack.Conns[fp] + + conntrack.Unlock() + return c +} + +func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) *conn { + //big todo, this cache needs to know snat info as well + //if localCache != nil { + // if out, ok := localCache[fp]; ok { + // return out + // } + //} conntrack := f.Conntrack conntrack.Lock() @@ -513,7 +631,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, if !ok { conntrack.Unlock() - return false + return nil } if c.rulesVersion != f.rulesVersion { @@ -536,7 +654,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, } delete(conntrack.Conns, fp) conntrack.Unlock() - return false + return nil } if f.l.Level >= logrus.DebugLevel { @@ -566,12 +684,11 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache[fp] = struct{}{} } - return true + return c } -func (f *Firewall) addConn(fp firewall.Packet, incoming bool) { +func (f *Firewall) packetTimeout(fp firewall.Packet) time.Duration { var timeout time.Duration - c := &conn{} switch fp.Protocol { case firewall.ProtoTCP: @@ -581,7 +698,27 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) { default: timeout = f.DefaultTimeout } + return timeout +} +func (f *Firewall) dupeConn(fp firewall.Packet, c *conn) { + conntrack := f.Conntrack + conntrack.Lock() + if _, ok := conntrack.Conns[fp]; !ok { + conntrack.TimerWheel.Advance(time.Now()) + conntrack.TimerWheel.Add(fp, f.packetTimeout(fp)) + } + + // Record which rulesVersion allowed this connection, so we can retest after + // firewall reload + conntrack.Conns[fp] = c + conntrack.Unlock() +} + +func (f *Firewall) addConn(fp firewall.Packet, incoming bool) *conn { + c := &conn{} + + timeout := f.packetTimeout(fp) conntrack := f.Conntrack conntrack.Lock() if _, ok := conntrack.Conns[fp]; !ok { @@ -596,6 +733,7 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) { c.Expires = time.Now().Add(timeout) conntrack.Conns[fp] = c conntrack.Unlock() + return c } // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel diff --git a/firewall_test.go b/firewall_test.go index 1df62a81..1430c3da 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -212,44 +212,44 @@ func TestFirewall_Drop(t *testing.T) { cp := cert.NewCAPool() // Drop outbound - assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil)) + assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, nil, false, &h, cp, nil)) // Allow inbound resetConntrack(fw) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) // Allow outbound because conntrack - require.NoError(t, fw.Drop(p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil)) // test remote mismatch oldRemote := p.RemoteAddr p.RemoteAddr = netip.MustParseAddr("1.2.3.10") - assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP) + assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrInvalidRemoteIP) p.RemoteAddr = oldRemote // ensure signer doesn't get in the way of group checks fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) - assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) - assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) } func TestFirewall_DropV6(t *testing.T) { @@ -291,44 +291,44 @@ func TestFirewall_DropV6(t *testing.T) { cp := cert.NewCAPool() // Drop outbound - assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil)) + assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, nil, false, &h, cp, nil)) // Allow inbound resetConntrack(fw) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) // Allow outbound because conntrack - require.NoError(t, fw.Drop(p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil)) // test remote mismatch oldRemote := p.RemoteAddr p.RemoteAddr = netip.MustParseAddr("fd12::56") - assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP) + assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrInvalidRemoteIP) p.RemoteAddr = oldRemote // ensure signer doesn't get in the way of group checks fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) - assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) - assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) } func BenchmarkFirewallTable_match(b *testing.B) { @@ -536,10 +536,10 @@ func TestFirewall_Drop2(t *testing.T) { cp := cert.NewCAPool() // h1/c1 lacks the proper groups - require.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule) + require.ErrorIs(t, fw.Drop(p, nil, true, &h1, cp, nil), ErrNoMatchingRule) // c has the proper groups resetConntrack(fw) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) } func TestFirewall_Drop3(t *testing.T) { @@ -617,18 +617,18 @@ func TestFirewall_Drop3(t *testing.T) { cp := cert.NewCAPool() // c1 should pass because host match - require.NoError(t, fw.Drop(p, true, &h1, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h1, cp, nil)) // c2 should pass because ca sha match resetConntrack(fw) - require.NoError(t, fw.Drop(p, true, &h2, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h2, cp, nil)) // c3 should fail because no match resetConntrack(fw) - assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p, nil, true, &h3, cp, nil), ErrNoMatchingRule) // Test a remote address match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", "")) - require.NoError(t, fw.Drop(p, true, &h1, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h1, cp, nil)) } func TestFirewall_Drop3V6(t *testing.T) { @@ -666,7 +666,7 @@ func TestFirewall_Drop3V6(t *testing.T) { fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) cp := cert.NewCAPool() require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", "")) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) } func TestFirewall_DropConntrackReload(t *testing.T) { @@ -708,12 +708,12 @@ func TestFirewall_DropConntrackReload(t *testing.T) { cp := cert.NewCAPool() // Drop outbound - assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) // Allow outbound because conntrack - require.NoError(t, fw.Drop(p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil)) oldFw := fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) @@ -722,7 +722,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { fw.rulesVersion = oldFw.rulesVersion + 1 // Allow outbound because conntrack and new rules allow port 10 - require.NoError(t, fw.Drop(p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil)) oldFw = fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) @@ -731,7 +731,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { fw.rulesVersion = oldFw.rulesVersion + 1 // Drop outbound because conntrack doesn't match new ruleset - assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrNoMatchingRule) } func TestFirewall_DropIPSpoofing(t *testing.T) { @@ -777,7 +777,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) { Protocol: firewall.ProtoUDP, Fragment: false, } - assert.Equal(t, fw.Drop(p, true, &h1, cp, nil), ErrInvalidRemoteIP) + assert.Equal(t, fw.Drop(p, nil, true, &h1, cp, nil), ErrInvalidRemoteIP) } func BenchmarkLookup(b *testing.B) { @@ -1184,7 +1184,7 @@ func (c *testcase) Test(t *testing.T, fw *Firewall) { t.Helper() cp := cert.NewCAPool() resetConntrack(fw) - err := fw.Drop(c.p, true, c.h, cp, nil) + err := fw.Drop(c.p, nil, true, c.h, cp, nil) if c.err == nil { require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr) } else { diff --git a/inside.go b/inside.go index 3735e470..4ad10e78 100644 --- a/inside.go +++ b/inside.go @@ -49,18 +49,18 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } var hostinfo *HostInfo - var destVpnAddr netip.Addr var ready bool - if fwPacket.RemoteAddr == f.snatMaps.snatIP { - //todo unsnat happens here - destVpnAddr = f.unSnat(packet, fwPacket) //todo bail if we can't unsnat? - } + snatMode := fwPacket.RemoteAddr == f.firewall.snatAddr - if destVpnAddr.IsValid() { - hostinfo, ready = f.getOrHandshakeNoRouting(destVpnAddr, func(hh *HandshakeHostInfo) { - hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) - }) + if snatMode { + //todo unsnat happens here, would be nice to not + destVpnAddr := f.firewall.unSnat(packet, fwPacket, nil, f.pki.GetCAPool()) //todo bail if we can't unsnat? + if destVpnAddr.IsValid() { + hostinfo, ready = f.getOrHandshakeNoRouting(destVpnAddr, func(hh *HandshakeHostInfo) { + hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) + }) + } //otherwise, hostinfo will be nil } else { //if we didn't need to unsnat hostinfo, ready = f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) { hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) @@ -81,10 +81,9 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet return } - dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) + dropReason := f.firewall.Drop(*fwPacket, packet, false, hostinfo, f.pki.GetCAPool(), localCache) if dropReason == nil { f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q) - } else { f.rejectInside(packet, out, q) if f.l.Level >= logrus.DebugLevel { @@ -233,7 +232,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp } // check if packet is in outbound fw rules - dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil) + dropReason := f.firewall.Drop(*fp, nil, false, hostinfo, f.pki.GetCAPool(), nil) if dropReason != nil { if f.l.Level >= logrus.DebugLevel { f.l.WithField("fwPacket", fp). diff --git a/interface.go b/interface.go index aa9ba50c..9f83d183 100644 --- a/interface.go +++ b/interface.go @@ -50,26 +50,6 @@ type InterfaceConfig struct { l *logrus.Logger } -type SnatMapping struct { - Src netip.AddrPort - SrcVpnIp netip.Addr -} - -type SnatMap struct { - m map[uint16]SnatMapping -} - -func (s *SnatMap) addMapping(src netip.AddrPort) { - -} - -type SnatMaps struct { - TCP SnatMap - UDP SnatMap - ICMP SnatMap //todo index? - snatIP netip.Addr -} - type Interface struct { hostMap *HostMap outside udp.Conn @@ -105,9 +85,8 @@ type Interface struct { conntrackCacheTimeout time.Duration - writers []udp.Conn - readers []io.ReadWriteCloser - snatMaps *SnatMaps //todo this needs some kind of atomic semantics for cross-routine access + writers []udp.Conn + readers []io.ReadWriteCloser metricHandshakes metrics.Histogram messageMetrics *MessageMetrics @@ -184,33 +163,21 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { cs := c.pki.getCertState() ifce := &Interface{ - pki: c.pki, - hostMap: c.HostMap, - outside: c.Outside, - inside: c.Inside, - firewall: c.Firewall, - serveDns: c.ServeDns, - handshakeManager: c.HandshakeManager, - createTime: time.Now(), - lightHouse: c.lightHouse, - dropLocalBroadcast: c.DropLocalBroadcast, - dropMulticast: c.DropMulticast, - routines: c.routines, - version: c.version, - writers: make([]udp.Conn, c.routines), - readers: make([]io.ReadWriteCloser, c.routines), - snatMaps: &SnatMaps{ - TCP: SnatMap{ - m: map[uint16]SnatMapping{}, - }, - UDP: SnatMap{ - m: map[uint16]SnatMapping{}, - }, - ICMP: SnatMap{ - m: map[uint16]SnatMapping{}, - }, - snatIP: srcsnortaddr, //todo this should be source of truthed here - }, + pki: c.pki, + hostMap: c.HostMap, + outside: c.Outside, + inside: c.Inside, + firewall: c.Firewall, + serveDns: c.ServeDns, + handshakeManager: c.HandshakeManager, + createTime: time.Now(), + lightHouse: c.lightHouse, + dropLocalBroadcast: c.DropLocalBroadcast, + dropMulticast: c.DropMulticast, + routines: c.routines, + version: c.version, + writers: make([]udp.Conn, c.routines), + readers: make([]io.ReadWriteCloser, c.routines), myVpnNetworks: cs.myVpnNetworks, myVpnNetworksTable: cs.myVpnNetworksTable, myVpnAddrs: cs.myVpnAddrs, diff --git a/outside.go b/outside.go index e5137e48..a1fa44bf 100644 --- a/outside.go +++ b/outside.go @@ -400,85 +400,6 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error { return ErrIPv6CouldNotFindPayload } -var srcsnortaddr = netip.MustParseAddr("169.254.55.96") - -func CalculateIPv4Checksum(header []byte) uint16 { - //todo this should be elsewhere - headerLen := int(header[0]&0x0F) * 4 - - if len(header) < headerLen { - return 0 - } - - var sum uint32 - for i := 0; i < headerLen; i += 2 { - word := uint32(binary.BigEndian.Uint16(header[i : i+2])) - sum += word - } - - for sum > 0xFFFF { - sum = (sum & 0xFFFF) + (sum >> 16) - } - - return uint16(^sum) -} - -func recalcIPv4Checksum(data []byte) { - data[10] = 0 - data[11] = 0 - checksum := CalculateIPv4Checksum(data) - binary.BigEndian.PutUint16(data[10:12], checksum) -} - -func (f *Interface) unSnat(data []byte, fp *firewall.Packet) netip.Addr { - var mapping SnatMapping - var ok bool - switch fp.Protocol { - case firewall.ProtoICMP: - //todo hack - mapping, ok = f.snatMaps.ICMP.m[0] - default: - f.l.WithField("fwPacket", fp).Warn("Unsupported unSNAT protocol") - return netip.Addr{} - } - if !ok { - f.l.WithField("fwPacket", fp).Warn("got a snat packet we don't know how to unsnat") - return netip.Addr{} - } - - copy(data[16:], mapping.Src.Addr().AsSlice()) - - recalcIPv4Checksum(data) - return mapping.SrcVpnIp -} - -func (f *Interface) applySnat(data []byte, fp *firewall.Packet, hostinfo *HostInfo) { - if !f.snatMaps.snatIP.Is4() { - return //bad! - } - - //todo math should exist to take existing checksum, old ip, new ip, and set new checksum, right? - - //todo set srcport - //todo record mapping somehow??? sadly the somehow has to be safe/sane across all routines - switch fp.Protocol { - case firewall.ProtoICMP, firewall.ProtoICMPv6: - f.snatMaps.ICMP.m[0] = SnatMapping{ - Src: netip.AddrPortFrom(fp.RemoteAddr, fp.RemotePort), - SrcVpnIp: hostinfo.vpnAddrs[0], //todo I hope this is ipv6 - } - case firewall.ProtoTCP: - //todo - case firewall.ProtoUDP: - //also todo - } - - fp.RemoteAddr = f.snatMaps.snatIP - copy(data[12:], f.snatMaps.snatIP.AsSlice()) - - recalcIPv4Checksum(data) -} - func parseV4(data []byte, incoming bool, fp *firewall.Packet) error { // Do we at least have an ipv4 header worth of data? if len(data) < ipv4.HeaderLen { @@ -573,17 +494,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out return false } - //todo apply srcsnort here? - //todo rp_filter will need to be set or defeated somehow - if fwPacket.IsIPv4() && hostinfo.HasOnlyV6Addresses() { - if len(f.pki.getCertState().GetDefaultCertificate().UnsafeNetworks()) != 0 { - //todo do not snat if you are not a router for the destination -- for now, just if you're not a router - //f.myVpnNetworksTable.Contains(fwPacket.RemoteAddr) - f.applySnat(out, fwPacket, hostinfo) - } - } - - dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache) + dropReason := f.firewall.Drop(*fwPacket, out, true, hostinfo, f.pki.GetCAPool(), localCache) if dropReason != nil { // NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore // This gives us a buffer to build the reject packet in diff --git a/snat.go b/snat.go new file mode 100644 index 00000000..a80766e8 --- /dev/null +++ b/snat.go @@ -0,0 +1,72 @@ +package nebula + +import ( + "encoding/binary" + "net/netip" +) + +func CalculateIPv4Checksum(header []byte) uint16 { + //todo this should be elsewhere + headerLen := int(header[0]&0x0F) * 4 + + if len(header) < headerLen { + return 0 + } + + var sum uint32 + for i := 0; i < headerLen; i += 2 { + word := uint32(binary.BigEndian.Uint16(header[i : i+2])) + sum += word + } + + for sum > 0xFFFF { + sum = (sum & 0xFFFF) + (sum >> 16) + } + + return uint16(^sum) +} + +func recalcIPv4Checksum(data []byte) { + data[10] = 0 + data[11] = 0 + checksum := CalculateIPv4Checksum(data) + binary.BigEndian.PutUint16(data[10:12], checksum) +} + +func CalcNewUDPChecksum(oldChecksum uint16, oldSrcIP, newSrcIP netip.Addr, oldSrcPort, newSrcPort uint16) uint16 { + // Convert IPs to uint32 + oldIP := binary.BigEndian.Uint32(oldSrcIP.AsSlice()) + newIP := binary.BigEndian.Uint32(newSrcIP.AsSlice()) + + // Start with inverted checksum + checksum := uint32(^oldChecksum) + + // Subtract old IP (as two 16-bit words) + checksum += uint32(^uint16(oldIP >> 16)) + checksum += uint32(^uint16(oldIP & 0xFFFF)) + + // Subtract old port + checksum += uint32(^oldSrcPort) + + // Add new IP (as two 16-bit words) + checksum += uint32(newIP >> 16) + checksum += uint32(newIP & 0xFFFF) + + // Add new port + checksum += uint32(newSrcPort) + + // Fold carries + for checksum > 0xFFFF { + checksum = (checksum & 0xFFFF) + (checksum >> 16) + } + + // Return ones' complement + return ^uint16(checksum) +} + +func recalcUDPv4Checksum(data []byte, oldSrcIP, newSrcIP netip.Addr, oldSrcPort, newSrcPort uint16) { + const UDPChecksumOffset = 20 + 6 //todo pls no options pls, big bad stupid hack + oldcsum := binary.BigEndian.Uint16(data[UDPChecksumOffset : UDPChecksumOffset+2]) + checksum := CalcNewUDPChecksum(oldcsum, oldSrcIP, newSrcIP, oldSrcPort, newSrcPort) + binary.BigEndian.PutUint16(data[UDPChecksumOffset:UDPChecksumOffset+2], checksum) +} From a9a9e58440b1ce52c4195d3f5c7a937481e14a22 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Thu, 15 Jan 2026 10:41:02 -0600 Subject: [PATCH 05/16] tcp "works" --- firewall.go | 18 ++++++------------ snat.go | 7 +++++++ 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/firewall.go b/firewall.go index 3ec512d6..956ce6e8 100644 --- a/firewall.go +++ b/firewall.go @@ -438,12 +438,16 @@ func (f *Firewall) unSnat(data []byte, fp *firewall.Packet, c *conn, caPool *cer return netip.Addr{} } + //change dst IP copy(data[16:], c.snat.Src.Addr().AsSlice()) recalcIPv4Checksum(data) switch fp.Protocol { case firewall.ProtoUDP: recalcUDPv4Checksum(data, f.snatAddr, c.snat.Src.Addr(), fp.RemotePort, c.snat.Src.Port()) + case firewall.ProtoTCP: + recalcTCPv4Checksum(data, f.snatAddr, c.snat.Src.Addr(), fp.RemotePort, c.snat.Src.Port()) + } return c.snat.SrcVpnIp } @@ -452,21 +456,11 @@ func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo //todo math should exist to take existing checksum, old ip, new ip, and set new checksum, right? //todo set srcport - //todo record mapping somehow??? sadly the somehow has to be safe/sane across all routines - //switch fp.Protocol { - //case firewall.ProtoICMP: - // - // - //case firewall.ProtoTCP: - // //todo - //case firewall.ProtoUDP: - // //also todo - //} - // c.snat.Src = netip.AddrPortFrom(fp.RemoteAddr, fp.RemotePort) c.snat.SrcVpnIp = hostinfo.vpnAddrs[0] //todo I hope this is ipv6 fp.RemoteAddr = f.snatAddr + //change src IP copy(data[12:], f.snatAddr.AsSlice()) recalcIPv4Checksum(data) @@ -475,7 +469,7 @@ func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo //todo change the port recalcUDPv4Checksum(data, c.snat.Src.Addr(), f.snatAddr, c.snat.Src.Port(), fp.RemotePort) case firewall.ProtoTCP: - //todo recalc checksum + recalcTCPv4Checksum(data, c.snat.Src.Addr(), f.snatAddr, c.snat.Src.Port(), fp.RemotePort) } } diff --git a/snat.go b/snat.go index a80766e8..d52493e0 100644 --- a/snat.go +++ b/snat.go @@ -70,3 +70,10 @@ func recalcUDPv4Checksum(data []byte, oldSrcIP, newSrcIP netip.Addr, oldSrcPort, checksum := CalcNewUDPChecksum(oldcsum, oldSrcIP, newSrcIP, oldSrcPort, newSrcPort) binary.BigEndian.PutUint16(data[UDPChecksumOffset:UDPChecksumOffset+2], checksum) } + +func recalcTCPv4Checksum(data []byte, oldSrcIP, newSrcIP netip.Addr, oldSrcPort, newSrcPort uint16) { + const TCPChecksumOffset = 20 + 16 //todo pls no options pls, big bad stupid hack + oldcsum := binary.BigEndian.Uint16(data[TCPChecksumOffset : TCPChecksumOffset+2]) + checksum := CalcNewUDPChecksum(oldcsum, oldSrcIP, newSrcIP, oldSrcPort, newSrcPort) + binary.BigEndian.PutUint16(data[TCPChecksumOffset:TCPChecksumOffset+2], checksum) +} From b10f977f6d804582a0c6e87c2f8efab446620952 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Thu, 15 Jan 2026 11:43:39 -0600 Subject: [PATCH 06/16] more efficient checksumming --- firewall.go | 26 +++++++++++++------------ snat.go | 56 ++++++++++++++++++----------------------------------- 2 files changed, 33 insertions(+), 49 deletions(-) diff --git a/firewall.go b/firewall.go index 956ce6e8..47c55502 100644 --- a/firewall.go +++ b/firewall.go @@ -438,16 +438,18 @@ func (f *Firewall) unSnat(data []byte, fp *firewall.Packet, c *conn, caPool *cer return netip.Addr{} } + oldIP := netip.AddrPortFrom(f.snatAddr, fp.RemotePort) + //change dst IP copy(data[16:], c.snat.Src.Addr().AsSlice()) - recalcIPv4Checksum(data) + recalcIPv4Checksum(data, oldIP.Addr(), c.snat.Src.Addr()) + switch fp.Protocol { case firewall.ProtoUDP: - recalcUDPv4Checksum(data, f.snatAddr, c.snat.Src.Addr(), fp.RemotePort, c.snat.Src.Port()) + recalcUDPv4Checksum(data, oldIP, c.snat.Src) case firewall.ProtoTCP: - recalcTCPv4Checksum(data, f.snatAddr, c.snat.Src.Addr(), fp.RemotePort, c.snat.Src.Port()) - + recalcTCPv4Checksum(data, oldIP, c.snat.Src) } return c.snat.SrcVpnIp } @@ -455,6 +457,7 @@ func (f *Firewall) unSnat(data []byte, fp *firewall.Packet, c *conn, caPool *cer func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo *HostInfo) { //todo math should exist to take existing checksum, old ip, new ip, and set new checksum, right? + newIP := netip.AddrPortFrom(f.snatAddr, fp.RemotePort) //todo actually change remoteport //todo set srcport c.snat.Src = netip.AddrPortFrom(fp.RemoteAddr, fp.RemotePort) c.snat.SrcVpnIp = hostinfo.vpnAddrs[0] //todo I hope this is ipv6 @@ -463,23 +466,19 @@ func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo //change src IP copy(data[12:], f.snatAddr.AsSlice()) - recalcIPv4Checksum(data) + recalcIPv4Checksum(data, c.snat.Src.Addr(), newIP.Addr()) switch fp.Protocol { case firewall.ProtoUDP: //todo change the port - recalcUDPv4Checksum(data, c.snat.Src.Addr(), f.snatAddr, c.snat.Src.Port(), fp.RemotePort) + recalcUDPv4Checksum(data, c.snat.Src, newIP) case firewall.ProtoTCP: - recalcTCPv4Checksum(data, c.snat.Src.Addr(), f.snatAddr, c.snat.Src.Port(), fp.RemotePort) + recalcTCPv4Checksum(data, c.snat.Src, newIP) } } // Drop returns an error if the packet should be dropped, explaining why. It // returns nil if the packet should not be dropped. func (f *Firewall) Drop(fp firewall.Packet, pkt []byte, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error { - // Check if we spoke to this tuple, if we did then allow this packet - c := f.inConns(fp, h, caPool, localCache) - //can't return yet, need to snat maybe - var err error specialSnatMode := false @@ -488,8 +487,11 @@ func (f *Firewall) Drop(fp firewall.Packet, pkt []byte, incoming bool, h *HostIn table = f.InRules } + // Check if we spoke to this tuple, if we did then allow this packet + c := f.inConns(fp, h, caPool, localCache) if c != nil { - specialSnatMode = fp.IsIPv4() && h.HasOnlyV6Addresses() //todo? + //can't return yet, need to snat maybe + specialSnatMode = fp.IsIPv4() && h.HasOnlyV6Addresses() //todo I wish I only set this once somehow goto snat } diff --git a/snat.go b/snat.go index d52493e0..e9c7a804 100644 --- a/snat.go +++ b/snat.go @@ -5,36 +5,14 @@ import ( "net/netip" ) -func CalculateIPv4Checksum(header []byte) uint16 { - //todo this should be elsewhere - headerLen := int(header[0]&0x0F) * 4 - - if len(header) < headerLen { - return 0 - } - - var sum uint32 - for i := 0; i < headerLen; i += 2 { - word := uint32(binary.BigEndian.Uint16(header[i : i+2])) - sum += word - } - - for sum > 0xFFFF { - sum = (sum & 0xFFFF) + (sum >> 16) - } - - return uint16(^sum) -} - -func recalcIPv4Checksum(data []byte) { - data[10] = 0 - data[11] = 0 - checksum := CalculateIPv4Checksum(data) +func recalcIPv4Checksum(data []byte, oldSrcIP netip.Addr, newSrcIP netip.Addr) { + oldChecksum := binary.BigEndian.Uint16(data[10:12]) + //because of how checksums work, we can re-use this function + checksum := calcNewTransportChecksum(oldChecksum, oldSrcIP, 0, newSrcIP, 0) binary.BigEndian.PutUint16(data[10:12], checksum) } -func CalcNewUDPChecksum(oldChecksum uint16, oldSrcIP, newSrcIP netip.Addr, oldSrcPort, newSrcPort uint16) uint16 { - // Convert IPs to uint32 +func calcNewTransportChecksum(oldChecksum uint16, oldSrcIP netip.Addr, oldSrcPort uint16, newSrcIP netip.Addr, newSrcPort uint16) uint16 { oldIP := binary.BigEndian.Uint32(oldSrcIP.AsSlice()) newIP := binary.BigEndian.Uint32(newSrcIP.AsSlice()) @@ -64,16 +42,20 @@ func CalcNewUDPChecksum(oldChecksum uint16, oldSrcIP, newSrcIP netip.Addr, oldSr return ^uint16(checksum) } -func recalcUDPv4Checksum(data []byte, oldSrcIP, newSrcIP netip.Addr, oldSrcPort, newSrcPort uint16) { - const UDPChecksumOffset = 20 + 6 //todo pls no options pls, big bad stupid hack - oldcsum := binary.BigEndian.Uint16(data[UDPChecksumOffset : UDPChecksumOffset+2]) - checksum := CalcNewUDPChecksum(oldcsum, oldSrcIP, newSrcIP, oldSrcPort, newSrcPort) - binary.BigEndian.PutUint16(data[UDPChecksumOffset:UDPChecksumOffset+2], checksum) +func recalcV4TransportChecksum(offsetInsideHeader int, data []byte, oldSrcIP netip.AddrPort, newSrcIP netip.AddrPort) { + ipHeaderOffset := int(data[0]&0x0F) * 4 + offset := ipHeaderOffset + offsetInsideHeader + oldcsum := binary.BigEndian.Uint16(data[offset : offset+2]) + checksum := calcNewTransportChecksum(oldcsum, oldSrcIP.Addr(), oldSrcIP.Port(), newSrcIP.Addr(), newSrcIP.Port()) + binary.BigEndian.PutUint16(data[offset:offset+2], checksum) } -func recalcTCPv4Checksum(data []byte, oldSrcIP, newSrcIP netip.Addr, oldSrcPort, newSrcPort uint16) { - const TCPChecksumOffset = 20 + 16 //todo pls no options pls, big bad stupid hack - oldcsum := binary.BigEndian.Uint16(data[TCPChecksumOffset : TCPChecksumOffset+2]) - checksum := CalcNewUDPChecksum(oldcsum, oldSrcIP, newSrcIP, oldSrcPort, newSrcPort) - binary.BigEndian.PutUint16(data[TCPChecksumOffset:TCPChecksumOffset+2], checksum) +func recalcUDPv4Checksum(data []byte, oldSrcIP netip.AddrPort, newSrcIP netip.AddrPort) { + const offsetInsideHeader = 6 + recalcV4TransportChecksum(offsetInsideHeader, data, oldSrcIP, newSrcIP) +} + +func recalcTCPv4Checksum(data []byte, oldSrcIP netip.AddrPort, newSrcIP netip.AddrPort) { + const offsetInsideHeader = 16 + recalcV4TransportChecksum(offsetInsideHeader, data, oldSrcIP, newSrcIP) } From 2fd3da197a8e3a2b1d9931571b7290edcf654b59 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Thu, 15 Jan 2026 12:10:05 -0600 Subject: [PATCH 07/16] autoconfigure the snat return route for the device, if needed --- main.go | 3 ++- overlay/tun.go | 10 ++++---- overlay/tun_linux.go | 61 ++++++++++++++++++++++++++++++++++---------- 3 files changed, 55 insertions(+), 19 deletions(-) diff --git a/main.go b/main.go index 7b326616..5a37eacd 100644 --- a/main.go +++ b/main.go @@ -135,7 +135,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg deviceFactory = overlay.NewDeviceFromConfig } - tun, err = deviceFactory(c, l, pki.getCertState().myVpnNetworks, routines) + cs := pki.getCertState() + tun, err = deviceFactory(c, l, cs.myVpnNetworks, cs.GetDefaultCertificate().UnsafeNetworks(), routines) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err) } diff --git a/overlay/tun.go b/overlay/tun.go index 3a61d186..35adbcf3 100644 --- a/overlay/tun.go +++ b/overlay/tun.go @@ -13,22 +13,22 @@ import ( const DefaultMTU = 1300 // TODO: We may be able to remove routines -type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) +type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error) -func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { +func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error) { switch { case c.GetBool("tun.disabled", false): tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) return tun, nil default: - return newTun(c, l, vpnNetworks, routines > 1) + return newTun(c, l, vpnNetworks, unsafeNetworks, routines > 1) } } func NewFdDeviceFromConfig(fd *int) DeviceFactory { - return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { - return newTunFromFd(c, l, *fd, vpnNetworks) + return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error) { + return newTunFromFd(c, l, *fd, vpnNetworks, unsafeNetworks) } } diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 32bf51f5..35c50f9e 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -25,14 +25,15 @@ import ( type tun struct { io.ReadWriteCloser - fd int - Device string - vpnNetworks []netip.Prefix - MaxMTU int - DefaultMTU int - TXQueueLen int - deviceIndex int - ioctlFd uintptr + fd int + Device string + vpnNetworks []netip.Prefix + unsafeNetworks []netip.Prefix + MaxMTU int + DefaultMTU int + TXQueueLen int + deviceIndex int + ioctlFd uintptr Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] @@ -65,10 +66,10 @@ type ifreqQLEN struct { pad [8]byte } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") - t, err := newTunGeneric(c, l, file, vpnNetworks) + t, err := newTunGeneric(c, l, file, vpnNetworks, unsafeNetworks) if err != nil { return nil, err } @@ -78,7 +79,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net return t, nil } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, multiqueue bool) (*tun, error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { // If /dev/net/tun doesn't exist, try to create it (will happen in docker) @@ -113,7 +114,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu name := strings.Trim(string(req.Name[:]), "\x00") file := os.NewFile(uintptr(fd), "/dev/net/tun") - t, err := newTunGeneric(c, l, file, vpnNetworks) + t, err := newTunGeneric(c, l, file, vpnNetworks, unsafeNetworks) if err != nil { return nil, err } @@ -123,11 +124,12 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu return t, nil } -func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) { +func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix) (*tun, error) { t := &tun{ ReadWriteCloser: file, fd: int(file.Fd()), vpnNetworks: vpnNetworks, + unsafeNetworks: unsafeNetworks, TXQueueLen: c.GetInt("tun.tx_queue", 500), useSystemRoutes: c.GetBool("tun.use_system_route_table", false), useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0), @@ -409,6 +411,27 @@ func (t *tun) setMTU() { } } +func (t *tun) setSnatRoute() error { + snataddr := netip.MustParsePrefix("169.254.55.96/32") //todo get this from elsewhere? Or maybe we should pick it, and feed it back out to the firewall? + dr := &net.IPNet{ + IP: snataddr.Masked().Addr().AsSlice(), + Mask: net.CIDRMask(snataddr.Bits(), snataddr.Addr().BitLen()), + } + + nr := netlink.Route{ + LinkIndex: t.deviceIndex, + Dst: dr, + //todo do we need these other options? + //MTU: t.DefaultMTU, + //AdvMSS: t.advMSS(Route{}), + Scope: unix.RT_SCOPE_LINK, + //Protocol: unix.RTPROT_KERNEL, + Table: unix.RT_TABLE_MAIN, + Type: unix.RTN_UNICAST, + } + return netlink.RouteReplace(&nr) +} + func (t *tun) setDefaultRoute(cidr netip.Prefix) error { dr := &net.IPNet{ IP: cidr.Masked().Addr().AsSlice(), @@ -485,6 +508,18 @@ func (t *tun) addRoutes(logErrors bool) error { } } + onlyV6Addresses := false + for _, n := range t.vpnNetworks { + if n.Addr().Is6() { + onlyV6Addresses = true + break + } + } + + if len(t.unsafeNetworks) != 0 && onlyV6Addresses { + return t.setSnatRoute() + } + return nil } From 96a5c258f8d035ae361d87d39343b19f5a910376 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Thu, 15 Jan 2026 13:02:20 -0600 Subject: [PATCH 08/16] snat ports! --- firewall.go | 75 +++++++++++++++++++++++++++++++++++++++++------------ inside.go | 4 +-- 2 files changed, 61 insertions(+), 18 deletions(-) diff --git a/firewall.go b/firewall.go index 47c55502..3e4edb70 100644 --- a/firewall.go +++ b/firewall.go @@ -2,6 +2,7 @@ package nebula import ( "crypto/sha256" + "encoding/binary" "encoding/hex" "errors" "fmt" @@ -27,8 +28,12 @@ type FirewallInterface interface { } type snatInfo struct { - Src netip.AddrPort + //Src is the source IP+port to write into unsafe-route-bound packet + Src netip.AddrPort + //SrcVpnIp is the overlay IP associated with this flow. It's needed to associate reply traffic so we can get it back to the right host. SrcVpnIp netip.Addr + //SnatPort is the port to rewrite into an overlay-bound packet + SnatPort uint16 } func (s *snatInfo) Valid() bool { @@ -167,12 +172,14 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D tmax = defaultTimeout } + hasV4Networks := false routableNetworks := new(bart.Lite) var assignedNetworks []netip.Prefix for _, network := range c.Networks() { nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen()) routableNetworks.Insert(nprefix) assignedNetworks = append(assignedNetworks, network) + hasV4Networks = hasV4Networks || network.Addr().Is4() } hasUnsafeNetworks := false @@ -182,8 +189,8 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D } snatAddr := netip.Addr{} - if hasUnsafeNetworks { - snatAddr = netip.MustParseAddr("169.254.55.96") + if hasUnsafeNetworks && !hasV4Networks { + snatAddr = netip.MustParseAddr("169.254.55.96") //todo this needs to come from the config, or perhaps the tun } return &Firewall{ @@ -424,12 +431,11 @@ var ErrPeerRejected = errors.New("remote address is not within a network that we var ErrInvalidRemoteIP = errors.New("remote address is not in remote certificate networks") var ErrInvalidLocalIP = errors.New("local address is not in list of handled local addresses") var ErrNoMatchingRule = errors.New("no matching rule in firewall table") -var ErrSnatRequired = errors.New("snat required to pass traffic") -func (f *Firewall) unSnat(data []byte, fp *firewall.Packet, c *conn, caPool *cert.CAPool) netip.Addr { +func (f *Firewall) unSnat(data []byte, fp *firewall.Packet, c *conn) netip.Addr { if c == nil { //unfortunately this needs to lock. Surely there's a better way, but I need to make this flow at all first. - c = f.peek(*fp, nil, caPool, nil) + c = f.peek(*fp, nil) } if c == nil { return netip.Addr{} @@ -442,36 +448,73 @@ func (f *Firewall) unSnat(data []byte, fp *firewall.Packet, c *conn, caPool *cer //change dst IP copy(data[16:], c.snat.Src.Addr().AsSlice()) - recalcIPv4Checksum(data, oldIP.Addr(), c.snat.Src.Addr()) + ipHeaderLen := int(data[0]&0x0F) * 4 + //dst port is at offset 2 + dstport := ipHeaderLen + 2 switch fp.Protocol { case firewall.ProtoUDP: + binary.BigEndian.PutUint16(data[dstport:dstport+2], c.snat.Src.Port()) recalcUDPv4Checksum(data, oldIP, c.snat.Src) case firewall.ProtoTCP: + binary.BigEndian.PutUint16(data[dstport:dstport+2], c.snat.Src.Port()) recalcTCPv4Checksum(data, oldIP, c.snat.Src) } return c.snat.SrcVpnIp } func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo *HostInfo) { - //todo math should exist to take existing checksum, old ip, new ip, and set new checksum, right? - - newIP := netip.AddrPortFrom(f.snatAddr, fp.RemotePort) //todo actually change remoteport //todo set srcport - c.snat.Src = netip.AddrPortFrom(fp.RemoteAddr, fp.RemotePort) - c.snat.SrcVpnIp = hostinfo.vpnAddrs[0] //todo I hope this is ipv6 - fp.RemoteAddr = f.snatAddr + if c.snat.Valid() { + //old flow + fp.RemoteAddr = f.snatAddr + fp.RemotePort = c.snat.SnatPort + } else if hostinfo.vpnAddrs[0].Is6() { + //we got a new flow + c.snat.Src = netip.AddrPortFrom(fp.RemoteAddr, fp.RemotePort) + c.snat.SrcVpnIp = hostinfo.vpnAddrs[0] + fp.RemoteAddr = f.snatAddr + + //find a new port to use, if needed + for { + existingFlow := f.peek(*fp, nil) //locking and unlocking for each peek is slow, but simple for now + if existingFlow == nil { + break //yay, we can use this port + } + //increment and retry. There's probably better strategies out there + fp.RemotePort++ + if fp.RemotePort < 0x7ff { + fp.RemotePort += 0x7ff // keep it ephemeral for now + } //todo if we're totally out of ports this loops forever. Probably not good. + } + c.snat.SnatPort = fp.RemotePort + } else { + f.l.WithFields(logrus.Fields{ + "fp": *fp, + "conn": *c, + "hostinfo": hostinfo, + }).Error("this packet cannot be snatted") + return + } + + newIP := netip.AddrPortFrom(f.snatAddr, fp.RemotePort) //change src IP copy(data[12:], f.snatAddr.AsSlice()) - recalcIPv4Checksum(data, c.snat.Src.Addr(), newIP.Addr()) + ipHeaderLen := int(data[0]&0x0F) * 4 + switch fp.Protocol { + case firewall.ProtoICMP: + //todo! case firewall.ProtoUDP: - //todo change the port + //src port is at offset 0 + binary.BigEndian.PutUint16(data[ipHeaderLen:ipHeaderLen+2], c.snat.SnatPort) recalcUDPv4Checksum(data, c.snat.Src, newIP) case firewall.ProtoTCP: + //src port is at offset 0 + binary.BigEndian.PutUint16(data[ipHeaderLen:ipHeaderLen+2], c.snat.SnatPort) recalcTCPv4Checksum(data, c.snat.Src, newIP) } } @@ -585,7 +628,7 @@ func (f *Firewall) EmitStats() { metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV())) } -func (f *Firewall) peek(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) *conn { +func (f *Firewall) peek(fp firewall.Packet, localCache firewall.ConntrackCache) *conn { //big todo, this cache needs to know snat info as well //if localCache != nil { // if out, ok := localCache[fp]; ok { diff --git a/inside.go b/inside.go index 4ad10e78..a8a0f86d 100644 --- a/inside.go +++ b/inside.go @@ -51,11 +51,11 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet var hostinfo *HostInfo var ready bool - snatMode := fwPacket.RemoteAddr == f.firewall.snatAddr + snatMode := f.firewall.snatAddr.IsValid() && fwPacket.RemoteAddr == f.firewall.snatAddr if snatMode { //todo unsnat happens here, would be nice to not - destVpnAddr := f.firewall.unSnat(packet, fwPacket, nil, f.pki.GetCAPool()) //todo bail if we can't unsnat? + destVpnAddr := f.firewall.unSnat(packet, fwPacket, nil) //todo bail if we can't unsnat? if destVpnAddr.IsValid() { hostinfo, ready = f.getOrHandshakeNoRouting(destVpnAddr, func(hh *HandshakeHostInfo) { hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) From c984cbe673078e8df0e643ae85f09b9eef071c37 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Thu, 15 Jan 2026 13:31:42 -0600 Subject: [PATCH 09/16] restore conntrack cache --- firewall.go | 41 +++++++++++++++++------------------------ 1 file changed, 17 insertions(+), 24 deletions(-) diff --git a/firewall.go b/firewall.go index 3e4edb70..b3a63150 100644 --- a/firewall.go +++ b/firewall.go @@ -522,8 +522,7 @@ func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo // Drop returns an error if the packet should be dropped, explaining why. It // returns nil if the packet should not be dropped. func (f *Firewall) Drop(fp firewall.Packet, pkt []byte, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error { - var err error - specialSnatMode := false + specialSnatMode := f.hasUnsafeNetworks && fp.IsIPv4() && h.HasOnlyV6Addresses() //todo I wish I only set this once somehow table := f.OutRules if incoming { @@ -531,10 +530,15 @@ func (f *Firewall) Drop(fp firewall.Packet, pkt []byte, incoming bool, h *HostIn } // Check if we spoke to this tuple, if we did then allow this packet + // Check the cache first, iff not snatting + if localCache != nil && !specialSnatMode { + if _, ok := localCache[fp]; ok { + return nil //packet matched the cache, we're not snatting, we can return early! + } + } c := f.inConns(fp, h, caPool, localCache) if c != nil { //can't return yet, need to snat maybe - specialSnatMode = fp.IsIPv4() && h.HasOnlyV6Addresses() //todo I wish I only set this once somehow goto snat } @@ -542,12 +546,11 @@ func (f *Firewall) Drop(fp firewall.Packet, pkt []byte, incoming bool, h *HostIn if h.networks == nil { // Simple case: Certificate has one address and no unsafe networks if h.vpnAddrs[0] != fp.RemoteAddr { - specialSnatMode = fp.IsIPv4() && h.HasOnlyV6Addresses() if !specialSnatMode { f.metrics(incoming).droppedRemoteAddr.Inc(1) return ErrInvalidRemoteIP //todo! - } - } + } //else we're in special snat mode, and we need to apply more checks below + } //else? all good, fall through } else { //todo check for srcsnortaddr here too? nwType, ok := h.networks.Lookup(fp.RemoteAddr) @@ -559,10 +562,12 @@ func (f *Firewall) Drop(fp firewall.Packet, pkt []byte, incoming bool, h *HostIn case NetworkTypeVPN: break // nothing special case NetworkTypeVPNPeer: + //todo we might need a specialSnatMode case in here to handle routers with v4 addresses when we don't also have a v4 address? f.metrics(incoming).droppedRemoteAddr.Inc(1) return ErrPeerRejected // reject for now, one day this may have different FW rules case NetworkTypeUnsafe: - specialSnatMode = fp.IsIPv4() && h.HasOnlyV6Addresses() + //intentionally excluding f.hasUnsafeNetworks -- this is what lets routers talk back to us with our unsafe traffic! + specialSnatMode = fp.IsIPv4() && h.HasOnlyV6Addresses() && f.assignedNetworks[0].Addr().Is6() break // nothing special, one day this may have different FW rules default: f.metrics(incoming).droppedRemoteAddr.Inc(1) @@ -588,20 +593,14 @@ func (f *Firewall) Drop(fp firewall.Packet, pkt []byte, incoming bool, h *HostIn snat: if incoming { - //todo rp_filter will need to be set or defeated somehow if specialSnatMode { - if f.hasUnsafeNetworks { - //todo do not snat if you are not a router for the destination -- for now, just if you're not a router - //f.myVpnNetworksTable.Contains(fwPacket.RemoteAddr) - f.applySnat(pkt, &fp, c, h) - f.dupeConn(fp, c) - } + //todo do not snat if you are not a router for the destination -- for now, just if you're not a router + f.applySnat(pkt, &fp, c, h) + f.dupeConn(fp, c) //track the snatted flow with the same expiration as the unsnatted version } - } else { //outgoing + } //outgoing snat is handled before this function is called (for now!) - } - - return err + return nil } func (f *Firewall) metrics(incoming bool) firewallMetrics { @@ -651,12 +650,6 @@ func (f *Firewall) peek(fp firewall.Packet, localCache firewall.ConntrackCache) } func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) *conn { - //big todo, this cache needs to know snat info as well - //if localCache != nil { - // if out, ok := localCache[fp]; ok { - // return out - // } - //} conntrack := f.Conntrack conntrack.Lock() From bde961a16432fb6e153f8f04c2249152ea8365b0 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Thu, 15 Jan 2026 13:33:18 -0600 Subject: [PATCH 10/16] fix newTun signatures --- overlay/tun_android.go | 2 +- overlay/tun_darwin.go | 2 +- overlay/tun_freebsd.go | 2 +- overlay/tun_ios.go | 2 +- overlay/tun_netbsd.go | 2 +- overlay/tun_openbsd.go | 2 +- overlay/tun_tester.go | 2 +- overlay/tun_windows.go | 2 +- overlay/user.go | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/overlay/tun_android.go b/overlay/tun_android.go index eddef882..0c30fb52 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -53,7 +53,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net return t, nil } -func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ []netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in Android") } diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 128c2001..1ec9e5db 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -79,7 +79,7 @@ type ifreqAlias6 struct { Lifetime addrLifetime } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) { name := c.GetString("tun.dev", "") ifIndex := -1 if name != "" && name != "utun" { diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 8d292263..e946b815 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -203,7 +203,7 @@ func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open existing tun device var fd int var err error diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index 0ce01df8..cae2fdd1 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -28,7 +28,7 @@ type tun struct { l *logrus.Logger } -func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ []netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in iOS") } diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index 2986c895..1336ccfb 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -74,7 +74,7 @@ func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, return nil, fmt.Errorf("newTunFromFd not supported in NetBSD") } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open tun device var err error deviceName := c.GetString("tun.dev", "") diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index 9209b795..6c67c8f0 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -67,7 +67,7 @@ func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, return nil, fmt.Errorf("newTunFromFd not supported in openbsd") } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open tun device var err error deviceName := c.GetString("tun.dev", "") diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index 3477de3d..e113b59d 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -28,7 +28,7 @@ type TestTun struct { TxPackets chan []byte // Packets transmitted outside by nebula } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*TestTun, error) { _, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true) if err != nil { return nil, err diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index b4d78b66..06af7e83 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -42,7 +42,7 @@ func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Devic return nil, fmt.Errorf("newTunFromFd not supported in Windows") } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*winTun, error) { err := checkWinTunExists() if err != nil { return nil, fmt.Errorf("can not load the wintun driver: %w", err) diff --git a/overlay/user.go b/overlay/user.go index 1f92d4e9..52fa0df7 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -9,7 +9,7 @@ import ( "github.com/slackhq/nebula/routing" ) -func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { +func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error) { return NewUserDevice(vpnNetworks) } From f9921bcf9b3357a01ff5f03348d78310c34aea08 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Thu, 15 Jan 2026 13:45:44 -0600 Subject: [PATCH 11/16] don't need that anymore --- firewall.go | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/firewall.go b/firewall.go index b3a63150..0221ca29 100644 --- a/firewall.go +++ b/firewall.go @@ -99,11 +99,6 @@ type FirewallConntrack struct { Conns map[firewall.Packet]*conn TimerWheel *TimerWheel[firewall.Packet] - // SNATFlows maps protocol->source_port->original packet info for unsnatting. - // the srcport to use for outgoing snat flows is stored in Conns. - // When a flow is expired from Conns, it needs to be removed from SNATFlows as well. - // todo if we put "both" keys into Conns, we can potentially avoid this problem - SNATFlows map[int]map[uint16]snatInfo } // FirewallTable is the entry point for a rule, the evaluation order is: @@ -435,7 +430,7 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table") func (f *Firewall) unSnat(data []byte, fp *firewall.Packet, c *conn) netip.Addr { if c == nil { //unfortunately this needs to lock. Surely there's a better way, but I need to make this flow at all first. - c = f.peek(*fp, nil) + c = f.peek(*fp) } if c == nil { return netip.Addr{} @@ -479,7 +474,7 @@ func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo //find a new port to use, if needed for { - existingFlow := f.peek(*fp, nil) //locking and unlocking for each peek is slow, but simple for now + existingFlow := f.peek(*fp) //locking and unlocking for each peek is slow, but simple for now if existingFlow == nil { break //yay, we can use this port } @@ -627,13 +622,7 @@ func (f *Firewall) EmitStats() { metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV())) } -func (f *Firewall) peek(fp firewall.Packet, localCache firewall.ConntrackCache) *conn { - //big todo, this cache needs to know snat info as well - //if localCache != nil { - // if out, ok := localCache[fp]; ok { - // return out - // } - //} +func (f *Firewall) peek(fp firewall.Packet) *conn { conntrack := f.Conntrack conntrack.Lock() From 84ddac0dedbdaa20847a6e759622802624751082 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Thu, 15 Jan 2026 14:10:10 -0600 Subject: [PATCH 12/16] don't need that anymore --- firewall.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/firewall.go b/firewall.go index 0221ca29..7cfc21f1 100644 --- a/firewall.go +++ b/firewall.go @@ -730,8 +730,6 @@ func (f *Firewall) dupeConn(fp firewall.Packet, c *conn) { conntrack.TimerWheel.Add(fp, f.packetTimeout(fp)) } - // Record which rulesVersion allowed this connection, so we can retest after - // firewall reload conntrack.Conns[fp] = c conntrack.Unlock() } From 3bc46a06478e71f6eabfd334959335493d8e7508 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Fri, 16 Jan 2026 11:26:49 -0600 Subject: [PATCH 13/16] fixy fixy --- overlay/tun_android.go | 2 +- overlay/tun_darwin.go | 2 +- overlay/tun_freebsd.go | 2 +- overlay/tun_ios.go | 2 +- overlay/tun_netbsd.go | 2 +- overlay/tun_openbsd.go | 2 +- overlay/tun_tester.go | 2 +- overlay/tun_windows.go | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/overlay/tun_android.go b/overlay/tun_android.go index 0c30fb52..f091772a 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -26,7 +26,7 @@ type tun struct { l *logrus.Logger } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix, _ []netip.Prefix) (*tun, error) { // XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly. // Be sure not to call file.Fd() as it will set the fd to blocking mode. file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 1ec9e5db..c9c3927e 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -153,7 +153,7 @@ func (t *tun) deviceBytes() (o [16]byte) { return } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in Darwin") } diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index e946b815..939e0569 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -199,7 +199,7 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") } diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index cae2fdd1..85466d1e 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -32,7 +32,7 @@ func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ []netip.Prefix, _ return nil, fmt.Errorf("newTun not supported in iOS") } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix, _ []netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/tun") t := &tun{ vpnNetworks: vpnNetworks, diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index 1336ccfb..39336108 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -70,7 +70,7 @@ type tun struct { var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in NetBSD") } diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index 6c67c8f0..701d97dd 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -63,7 +63,7 @@ type tun struct { var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in openbsd") } diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index e113b59d..145eccb9 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -49,7 +49,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNet }, nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*TestTun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*TestTun, error) { return nil, fmt.Errorf("newTunFromFd not supported") } diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index 06af7e83..18ee533f 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -38,7 +38,7 @@ type winTun struct { tun *wintun.NativeTun } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (Device, error) { return nil, fmt.Errorf("newTunFromFd not supported in Windows") } From 5fb407e92953e4c41bc044266a3cfd7585888758 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Wed, 21 Jan 2026 12:54:42 -0600 Subject: [PATCH 14/16] conntrack ICMPv4 based on code and identifier, without changing firewall filter behavior --- firewall.go | 7 +++ firewall/packet.go | 6 +- firewall_test.go | 139 +++++++++++++++++++++++++++++++++++++++++++++ outside.go | 25 +++++--- routing/balance.go | 7 ++- 5 files changed, 175 insertions(+), 9 deletions(-) diff --git a/firewall.go b/firewall.go index 7cfc21f1..3ddd85bf 100644 --- a/firewall.go +++ b/firewall.go @@ -830,6 +830,13 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCer var port int32 + if p.Protocol == firewall.ProtoICMP { + // port numbers are re-used for connection tracking and SNAT, + // but we don't want to actually filter on them for ICMP + // ICMP6 is omitted because we don't attempt to parse code/identifier/etc out of ICMP6 + return fp[firewall.PortAny].match(p, c, caPool) + } + if p.Fragment { port = firewall.PortFragment } else if incoming { diff --git a/firewall/packet.go b/firewall/packet.go index ade2ad55..8482d476 100644 --- a/firewall/packet.go +++ b/firewall/packet.go @@ -22,7 +22,11 @@ const ( type Packet struct { LocalAddr netip.Addr RemoteAddr netip.Addr - LocalPort uint16 + // LocalPort is the destination port for incoming traffic, or the source port for outgoing. + // For ICMP, it's the "identifier". This is only used for connection tracking, actual firewall rules will not filter on ICMP identifier + LocalPort uint16 + // RemotePort is the source port for incoming traffic, or the destination port for outgoing. + // For ICMP, it's "code". //todo also store "type?" would need to decode replies, which sucks RemotePort uint16 Protocol uint8 Fragment bool diff --git a/firewall_test.go b/firewall_test.go index 1430c3da..a1f5cacf 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -734,6 +734,145 @@ func TestFirewall_DropConntrackReload(t *testing.T) { assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrNoMatchingRule) } +func TestFirewall_ICMPPortBehavior(t *testing.T) { + l := test.NewLogger() + ob := &bytes.Buffer{} + l.SetOutput(ob) + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) + + network := netip.MustParsePrefix("1.2.3.4/24") + + c := cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "host1", + networks: []netip.Prefix{network}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + }, + InvertedGroups: map[string]struct{}{"default-group": {}}, + } + h := HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &c, + }, + vpnAddrs: []netip.Addr{network.Addr()}, + } + h.buildNetworks(myVpnNetworksTable, c.Certificate) + + cp := cert.NewCAPool() + + templ := firewall.Packet{ + LocalAddr: netip.MustParseAddr("1.2.3.4"), + RemoteAddr: netip.MustParseAddr("1.2.3.4"), + Protocol: firewall.ProtoICMP, + Fragment: false, + } + + t.Run("ICMP allowed", func(t *testing.T) { + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 0, 0, []string{"any"}, "", "", "", "", "")) + t.Run("zero ports", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 0 + p.RemotePort = 0 + // Drop outbound + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil)) + //now also allow outbound + require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil)) + }) + + t.Run("nonzero ports", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 0xabcd + p.RemotePort = 0x1234 + // Drop outbound + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil)) + //now also allow outbound + require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil)) + }) + }) + + t.Run("Any proto, some ports allowed", func(t *testing.T) { + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 80, 444, []string{"any"}, "", "", "", "", "")) + t.Run("zero ports, still blocked", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 0 + p.RemotePort = 0 + // Drop outbound + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule) + //now also allow outbound + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) + }) + + t.Run("nonzero ports, still blocked", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 0xabcd + p.RemotePort = 0x1234 + // Drop outbound + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule) + //now also allow outbound + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) + }) + + t.Run("nonzero, matching ports, still blocked", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 80 + p.RemotePort = 80 + // Drop outbound + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule) + //now also allow outbound + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) + }) + }) + t.Run("Any proto, any port", func(t *testing.T) { + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) + t.Run("zero ports, still blocked", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 0 + p.RemotePort = 0 + // Drop outbound + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil)) + //now also allow outbound + require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil)) + }) + + t.Run("nonzero ports, still blocked", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 0xabcd + p.RemotePort = 0x1234 + // Drop outbound + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil)) + //now also allow outbound + require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil)) + }) + }) + +} + func TestFirewall_DropIPSpoofing(t *testing.T) { l := test.NewLogger() ob := &bytes.Buffer{} diff --git a/outside.go b/outside.go index a1fa44bf..572f1cc8 100644 --- a/outside.go +++ b/outside.go @@ -329,7 +329,7 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error { switch proto { case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader: fp.Protocol = uint8(proto) - fp.RemotePort = 0 + fp.RemotePort = 0 //we don't attempt to parse ICMPv6 because we don't SNAT it fp.LocalPort = 0 fp.Fragment = false return nil @@ -434,22 +434,33 @@ func parseV4(data []byte, incoming bool, fp *firewall.Packet) error { if incoming { fp.RemoteAddr, _ = netip.AddrFromSlice(data[12:16]) fp.LocalAddr, _ = netip.AddrFromSlice(data[16:20]) - if fp.Fragment || fp.Protocol == firewall.ProtoICMP { + if fp.Fragment { fp.RemotePort = 0 fp.LocalPort = 0 + } else if fp.Protocol == firewall.ProtoICMP { + //todo remove comment + //icmpType := data[ihl] + //icmpCode := data[ihl+1] + //icmpChecksum := data[ihl+2 : ihl+4] + //icmpIdentifier := data[ihl+4 : ihl+6] + fp.RemotePort = uint16(data[ihl+1]) //code + fp.LocalPort = binary.BigEndian.Uint16(data[ihl+4 : ihl+6]) //identifier } else { - fp.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2]) - fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) + fp.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2]) //src port + fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) //dst port } } else { fp.LocalAddr, _ = netip.AddrFromSlice(data[12:16]) fp.RemoteAddr, _ = netip.AddrFromSlice(data[16:20]) - if fp.Fragment || fp.Protocol == firewall.ProtoICMP { + if fp.Fragment { fp.RemotePort = 0 fp.LocalPort = 0 + } else if fp.Protocol == firewall.ProtoICMP { + fp.RemotePort = uint16(data[ihl+1]) //code + fp.LocalPort = binary.BigEndian.Uint16(data[ihl+4 : ihl+6]) //identifier } else { - fp.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2]) - fp.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) + fp.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2]) //src port + fp.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) //dst port } } diff --git a/routing/balance.go b/routing/balance.go index 6f524970..22459113 100644 --- a/routing/balance.go +++ b/routing/balance.go @@ -12,7 +12,12 @@ import ( // - https://github.com/skeeto/hash-prospector // [16 21f0aaad 15 d35a2d97 15] = 0.10760229515479501 func hashPacket(p *firewall.Packet) int { - x := (uint32(p.LocalPort) << 16) | uint32(p.RemotePort) + var x uint32 + if p.Protocol == firewall.ProtoICMP { + x = 0 //Don't attempt to use ICMP's code/id/etc to balance + } else { + x = (uint32(p.LocalPort) << 16) | uint32(p.RemotePort) + } x ^= x >> 16 x *= 0x21f0aaad x ^= x >> 15 From 2c30c2edb93e9d116c3a4832684b85fc9bcde4eb Mon Sep 17 00:00:00 2001 From: JackDoan Date: Wed, 21 Jan 2026 13:22:59 -0600 Subject: [PATCH 15/16] SNAT the ICMPv4 identifier --- firewall.go | 12 ++++++++---- firewall/packet.go | 4 ++-- outside.go | 17 ++++++----------- snat.go | 30 ++++++++++++++++++++++++++++++ 4 files changed, 46 insertions(+), 17 deletions(-) diff --git a/firewall.go b/firewall.go index 3ddd85bf..1ff08d86 100644 --- a/firewall.go +++ b/firewall.go @@ -449,6 +449,10 @@ func (f *Firewall) unSnat(data []byte, fp *firewall.Packet, c *conn) netip.Addr dstport := ipHeaderLen + 2 switch fp.Protocol { + case firewall.ProtoICMP: + binary.BigEndian.PutUint16(data[ipHeaderLen+4:ipHeaderLen+6], c.snat.Src.Port()) + icmpCode := uint16(data[ipHeaderLen+1]) //todo not snatting on this yet (but Linux would) + recalcICMPv4Checksum(data, icmpCode, icmpCode, c.snat.SnatPort, c.snat.Src.Port()) case firewall.ProtoUDP: binary.BigEndian.PutUint16(data[dstport:dstport+2], c.snat.Src.Port()) recalcUDPv4Checksum(data, oldIP, c.snat.Src) @@ -460,7 +464,6 @@ func (f *Firewall) unSnat(data []byte, fp *firewall.Packet, c *conn) netip.Addr } func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo *HostInfo) { - //todo set srcport if c.snat.Valid() { //old flow fp.RemoteAddr = f.snatAddr @@ -471,7 +474,6 @@ func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo c.snat.SrcVpnIp = hostinfo.vpnAddrs[0] fp.RemoteAddr = f.snatAddr - //find a new port to use, if needed for { existingFlow := f.peek(*fp) //locking and unlocking for each peek is slow, but simple for now @@ -494,7 +496,7 @@ func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo return } - newIP := netip.AddrPortFrom(f.snatAddr, fp.RemotePort) + newIP := netip.AddrPortFrom(f.snatAddr, c.snat.SnatPort) //change src IP copy(data[12:], f.snatAddr.AsSlice()) recalcIPv4Checksum(data, c.snat.Src.Addr(), newIP.Addr()) @@ -502,7 +504,9 @@ func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo switch fp.Protocol { case firewall.ProtoICMP: - //todo! + binary.BigEndian.PutUint16(data[ipHeaderLen+4:ipHeaderLen+6], c.snat.SnatPort) + icmpCode := uint16(data[ipHeaderLen+1]) //todo not snatting on this yet (but Linux would) + recalcICMPv4Checksum(data, icmpCode, icmpCode, c.snat.Src.Port(), c.snat.SnatPort) case firewall.ProtoUDP: //src port is at offset 0 binary.BigEndian.PutUint16(data[ipHeaderLen:ipHeaderLen+2], c.snat.SnatPort) diff --git a/firewall/packet.go b/firewall/packet.go index 8482d476..ce00129c 100644 --- a/firewall/packet.go +++ b/firewall/packet.go @@ -23,10 +23,10 @@ type Packet struct { LocalAddr netip.Addr RemoteAddr netip.Addr // LocalPort is the destination port for incoming traffic, or the source port for outgoing. - // For ICMP, it's the "identifier". This is only used for connection tracking, actual firewall rules will not filter on ICMP identifier + // For ICMP, it's "code". //todo also store "type?" would need to decode replies, which sucks LocalPort uint16 // RemotePort is the source port for incoming traffic, or the destination port for outgoing. - // For ICMP, it's "code". //todo also store "type?" would need to decode replies, which sucks + // For ICMP, it's the "identifier". This is only used for connection tracking, actual firewall rules will not filter on ICMP identifier RemotePort uint16 Protocol uint8 Fragment bool diff --git a/outside.go b/outside.go index 572f1cc8..0c82f389 100644 --- a/outside.go +++ b/outside.go @@ -437,14 +437,9 @@ func parseV4(data []byte, incoming bool, fp *firewall.Packet) error { if fp.Fragment { fp.RemotePort = 0 fp.LocalPort = 0 - } else if fp.Protocol == firewall.ProtoICMP { - //todo remove comment - //icmpType := data[ihl] - //icmpCode := data[ihl+1] - //icmpChecksum := data[ihl+2 : ihl+4] - //icmpIdentifier := data[ihl+4 : ihl+6] - fp.RemotePort = uint16(data[ihl+1]) //code - fp.LocalPort = binary.BigEndian.Uint16(data[ihl+4 : ihl+6]) //identifier + } else if fp.Protocol == firewall.ProtoICMP { //note that orientation doesn't matter on ICMP + fp.RemotePort = binary.BigEndian.Uint16(data[ihl+4 : ihl+6]) //identifier + fp.LocalPort = uint16(data[ihl+1]) //code } else { fp.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2]) //src port fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) //dst port @@ -455,9 +450,9 @@ func parseV4(data []byte, incoming bool, fp *firewall.Packet) error { if fp.Fragment { fp.RemotePort = 0 fp.LocalPort = 0 - } else if fp.Protocol == firewall.ProtoICMP { - fp.RemotePort = uint16(data[ihl+1]) //code - fp.LocalPort = binary.BigEndian.Uint16(data[ihl+4 : ihl+6]) //identifier + } else if fp.Protocol == firewall.ProtoICMP { //note that orientation doesn't matter on ICMP + fp.RemotePort = binary.BigEndian.Uint16(data[ihl+4 : ihl+6]) //identifier + fp.LocalPort = uint16(data[ihl+1]) //code } else { fp.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2]) //src port fp.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) //dst port diff --git a/snat.go b/snat.go index e9c7a804..3164b641 100644 --- a/snat.go +++ b/snat.go @@ -59,3 +59,33 @@ func recalcTCPv4Checksum(data []byte, oldSrcIP netip.AddrPort, newSrcIP netip.Ad const offsetInsideHeader = 16 recalcV4TransportChecksum(offsetInsideHeader, data, oldSrcIP, newSrcIP) } + +func calcNewICMPChecksum(oldChecksum uint16, oldCode uint16, newCode uint16, oldID uint16, newID uint16) uint16 { + // Start with inverted checksum + checksum := uint32(^oldChecksum) + + // Subtract old stuff + checksum += uint32(^oldCode) + checksum += uint32(^oldID) + + // Add new stuff + checksum += uint32(newCode) + checksum += uint32(newID) + + // Fold carries + for checksum > 0xFFFF { + checksum = (checksum & 0xFFFF) + (checksum >> 16) + } + + // Return ones' complement + return ^uint16(checksum) +} + +func recalcICMPv4Checksum(data []byte, oldCode uint16, newCode uint16, oldID uint16, newID uint16) { + const offsetInsideHeader = 2 + ipHeaderOffset := int(data[0]&0x0F) * 4 + offset := ipHeaderOffset + offsetInsideHeader + oldChecksum := binary.BigEndian.Uint16(data[offset : offset+2]) + checksum := calcNewICMPChecksum(oldChecksum, oldCode, newCode, oldID, newID) + binary.BigEndian.PutUint16(data[offset:offset+2], checksum) +} From cf2b5455bf96486053c71a0298f0946dc83d7040 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Thu, 22 Jan 2026 11:32:37 -0600 Subject: [PATCH 16/16] I don't like this ICMP behavior change --- firewall.go | 23 +++++++++++++++++++++++ firewall_test.go | 18 ++++++++++++++++-- 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/firewall.go b/firewall.go index 1ff08d86..457210a6 100644 --- a/firewall.go +++ b/firewall.go @@ -654,6 +654,19 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, c, ok := conntrack.Conns[fp] + if !ok && fp.Protocol == firewall.ProtoICMP { + //todo this seems like it will also bite me + oldRemote := fp.RemotePort + oldLocal := fp.LocalPort + fp.RemotePort = 0 + fp.LocalPort = 0 + c, ok = conntrack.Conns[fp] + if ok { + fp.RemotePort = oldRemote + fp.LocalPort = oldLocal + } + } + if !ok { conntrack.Unlock() return nil @@ -755,6 +768,16 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) *conn { c.rulesVersion = f.rulesVersion c.Expires = time.Now().Add(timeout) conntrack.Conns[fp] = c + + //todo will this bite me somehow? + if fp.Protocol == firewall.ProtoICMP { + //not required for ICMPv6 because we don't decode or SNAT it + //create a duplicate conntrack entry with all the port information zeroed? + fp.RemotePort = 0 + fp.LocalPort = 0 + conntrack.Conns[fp] = c + } + conntrack.Unlock() return c } diff --git a/firewall_test.go b/firewall_test.go index a1f5cacf..77e7cc69 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -844,7 +844,7 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) { t.Run("Any proto, any port", func(t *testing.T) { fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) - t.Run("zero ports, still blocked", func(t *testing.T) { + t.Run("zero ports, allowed", func(t *testing.T) { p := templ.Copy() p.LocalPort = 0 p.RemotePort = 0 @@ -857,7 +857,7 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) { require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil)) }) - t.Run("nonzero ports, still blocked", func(t *testing.T) { + t.Run("nonzero ports, allowed", func(t *testing.T) { p := templ.Copy() p.LocalPort = 0xabcd p.RemotePort = 0x1234 @@ -869,6 +869,20 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) { //now also allow outbound require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil)) }) + + t.Run("nonzero ports, allowed", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 0xabcd + p.RemotePort = 0x1234 + // Drop outbound + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil)) + //now also allow outbound with a different ID + p.RemotePort++ + require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil)) + }) }) }