From 5fb407e92953e4c41bc044266a3cfd7585888758 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Wed, 21 Jan 2026 12:54:42 -0600 Subject: [PATCH] conntrack ICMPv4 based on code and identifier, without changing firewall filter behavior --- firewall.go | 7 +++ firewall/packet.go | 6 +- firewall_test.go | 139 +++++++++++++++++++++++++++++++++++++++++++++ outside.go | 25 +++++--- routing/balance.go | 7 ++- 5 files changed, 175 insertions(+), 9 deletions(-) diff --git a/firewall.go b/firewall.go index 7cfc21f1..3ddd85bf 100644 --- a/firewall.go +++ b/firewall.go @@ -830,6 +830,13 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCer var port int32 + if p.Protocol == firewall.ProtoICMP { + // port numbers are re-used for connection tracking and SNAT, + // but we don't want to actually filter on them for ICMP + // ICMP6 is omitted because we don't attempt to parse code/identifier/etc out of ICMP6 + return fp[firewall.PortAny].match(p, c, caPool) + } + if p.Fragment { port = firewall.PortFragment } else if incoming { diff --git a/firewall/packet.go b/firewall/packet.go index ade2ad55..8482d476 100644 --- a/firewall/packet.go +++ b/firewall/packet.go @@ -22,7 +22,11 @@ const ( type Packet struct { LocalAddr netip.Addr RemoteAddr netip.Addr - LocalPort uint16 + // LocalPort is the destination port for incoming traffic, or the source port for outgoing. + // For ICMP, it's the "identifier". This is only used for connection tracking, actual firewall rules will not filter on ICMP identifier + LocalPort uint16 + // RemotePort is the source port for incoming traffic, or the destination port for outgoing. + // For ICMP, it's "code". //todo also store "type?" would need to decode replies, which sucks RemotePort uint16 Protocol uint8 Fragment bool diff --git a/firewall_test.go b/firewall_test.go index 1430c3da..a1f5cacf 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -734,6 +734,145 @@ func TestFirewall_DropConntrackReload(t *testing.T) { assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrNoMatchingRule) } +func TestFirewall_ICMPPortBehavior(t *testing.T) { + l := test.NewLogger() + ob := &bytes.Buffer{} + l.SetOutput(ob) + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) + + network := netip.MustParsePrefix("1.2.3.4/24") + + c := cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "host1", + networks: []netip.Prefix{network}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + }, + InvertedGroups: map[string]struct{}{"default-group": {}}, + } + h := HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &c, + }, + vpnAddrs: []netip.Addr{network.Addr()}, + } + h.buildNetworks(myVpnNetworksTable, c.Certificate) + + cp := cert.NewCAPool() + + templ := firewall.Packet{ + LocalAddr: netip.MustParseAddr("1.2.3.4"), + RemoteAddr: netip.MustParseAddr("1.2.3.4"), + Protocol: firewall.ProtoICMP, + Fragment: false, + } + + t.Run("ICMP allowed", func(t *testing.T) { + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 0, 0, []string{"any"}, "", "", "", "", "")) + t.Run("zero ports", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 0 + p.RemotePort = 0 + // Drop outbound + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil)) + //now also allow outbound + require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil)) + }) + + t.Run("nonzero ports", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 0xabcd + p.RemotePort = 0x1234 + // Drop outbound + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil)) + //now also allow outbound + require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil)) + }) + }) + + t.Run("Any proto, some ports allowed", func(t *testing.T) { + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 80, 444, []string{"any"}, "", "", "", "", "")) + t.Run("zero ports, still blocked", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 0 + p.RemotePort = 0 + // Drop outbound + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule) + //now also allow outbound + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) + }) + + t.Run("nonzero ports, still blocked", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 0xabcd + p.RemotePort = 0x1234 + // Drop outbound + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule) + //now also allow outbound + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) + }) + + t.Run("nonzero, matching ports, still blocked", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 80 + p.RemotePort = 80 + // Drop outbound + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule) + //now also allow outbound + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) + }) + }) + t.Run("Any proto, any port", func(t *testing.T) { + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) + t.Run("zero ports, still blocked", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 0 + p.RemotePort = 0 + // Drop outbound + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil)) + //now also allow outbound + require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil)) + }) + + t.Run("nonzero ports, still blocked", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 0xabcd + p.RemotePort = 0x1234 + // Drop outbound + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil)) + //now also allow outbound + require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil)) + }) + }) + +} + func TestFirewall_DropIPSpoofing(t *testing.T) { l := test.NewLogger() ob := &bytes.Buffer{} diff --git a/outside.go b/outside.go index a1fa44bf..572f1cc8 100644 --- a/outside.go +++ b/outside.go @@ -329,7 +329,7 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error { switch proto { case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader: fp.Protocol = uint8(proto) - fp.RemotePort = 0 + fp.RemotePort = 0 //we don't attempt to parse ICMPv6 because we don't SNAT it fp.LocalPort = 0 fp.Fragment = false return nil @@ -434,22 +434,33 @@ func parseV4(data []byte, incoming bool, fp *firewall.Packet) error { if incoming { fp.RemoteAddr, _ = netip.AddrFromSlice(data[12:16]) fp.LocalAddr, _ = netip.AddrFromSlice(data[16:20]) - if fp.Fragment || fp.Protocol == firewall.ProtoICMP { + if fp.Fragment { fp.RemotePort = 0 fp.LocalPort = 0 + } else if fp.Protocol == firewall.ProtoICMP { + //todo remove comment + //icmpType := data[ihl] + //icmpCode := data[ihl+1] + //icmpChecksum := data[ihl+2 : ihl+4] + //icmpIdentifier := data[ihl+4 : ihl+6] + fp.RemotePort = uint16(data[ihl+1]) //code + fp.LocalPort = binary.BigEndian.Uint16(data[ihl+4 : ihl+6]) //identifier } else { - fp.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2]) - fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) + fp.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2]) //src port + fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) //dst port } } else { fp.LocalAddr, _ = netip.AddrFromSlice(data[12:16]) fp.RemoteAddr, _ = netip.AddrFromSlice(data[16:20]) - if fp.Fragment || fp.Protocol == firewall.ProtoICMP { + if fp.Fragment { fp.RemotePort = 0 fp.LocalPort = 0 + } else if fp.Protocol == firewall.ProtoICMP { + fp.RemotePort = uint16(data[ihl+1]) //code + fp.LocalPort = binary.BigEndian.Uint16(data[ihl+4 : ihl+6]) //identifier } else { - fp.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2]) - fp.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) + fp.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2]) //src port + fp.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) //dst port } } diff --git a/routing/balance.go b/routing/balance.go index 6f524970..22459113 100644 --- a/routing/balance.go +++ b/routing/balance.go @@ -12,7 +12,12 @@ import ( // - https://github.com/skeeto/hash-prospector // [16 21f0aaad 15 d35a2d97 15] = 0.10760229515479501 func hashPacket(p *firewall.Packet) int { - x := (uint32(p.LocalPort) << 16) | uint32(p.RemotePort) + var x uint32 + if p.Protocol == firewall.ProtoICMP { + x = 0 //Don't attempt to use ICMP's code/id/etc to balance + } else { + x = (uint32(p.LocalPort) << 16) | uint32(p.RemotePort) + } x ^= x >> 16 x *= 0x21f0aaad x ^= x >> 15