From 037459ef73f3d2ecd539f71e90ca80a0939a7700 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Fri, 27 Feb 2026 17:49:31 -0600 Subject: [PATCH] Review nits --- firewall.go | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/firewall.go b/firewall.go index 64ccb262..85e9f666 100644 --- a/firewall.go +++ b/firewall.go @@ -23,9 +23,14 @@ import ( "github.com/slackhq/nebula/firewall" ) -var ErrCannotSNAT = errors.New("cannot snat this packet") +var ErrCannotSNAT = errors.New("cannot SNAT this packet") var ErrSNATIdentityMismatch = errors.New("refusing to SNAT for mismatched host") +const ipv4SourcePosition = 12 +const ipv4DestinationPosition = 16 +const sourcePortOffset = 0 +const destinationPortOffset = 2 + type FirewallInterface interface { AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr string, caName string, caSha string) error } @@ -459,7 +464,7 @@ func (f *Firewall) unSnat(data []byte, fp *firewall.Packet) netip.Addr { return netip.Addr{} } oldIP := netip.AddrPortFrom(f.snatAddr, fp.RemotePort) - rewritePacket(data, fp, oldIP, c.snat.Src, 16, 2) + rewritePacket(data, fp, oldIP, c.snat.Src, ipv4DestinationPosition, destinationPortOffset) return c.snat.SrcVpnIp } @@ -496,6 +501,7 @@ func (f *Firewall) findUsableSNATPort(fp *firewall.Packet, c *conn) error { if !ok { //yay, we can use this port //track the snatted flow with the same expiration as the unsnatted version + c.snat.SnatPort = fp.RemotePort conntrack.Conns[*fp] = c return nil } @@ -538,13 +544,12 @@ func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo c.snat = nil return err } - c.snat.SnatPort = fp.RemotePort //may have been updated inside f.findUsableSNATPort } else { return ErrCannotSNAT } newIP := netip.AddrPortFrom(f.snatAddr, c.snat.SnatPort) - rewritePacket(data, fp, c.snat.Src, newIP, 12, 0) + rewritePacket(data, fp, c.snat.Src, newIP, ipv4SourcePosition, sourcePortOffset) return nil } @@ -695,18 +700,9 @@ func (f *Firewall) EmitStats() { } func (f *Firewall) peek(fp firewall.Packet) *conn { - conntrack := f.Conntrack - conntrack.Lock() - - // Purge every time we test - ep, has := conntrack.TimerWheel.Purge() - if has { - f.evict(ep) - } - - c := conntrack.Conns[fp] - - conntrack.Unlock() + f.Conntrack.Lock() + c := f.Conntrack.Conns[fp] + f.Conntrack.Unlock() return c }