From 2c30c2edb93e9d116c3a4832684b85fc9bcde4eb Mon Sep 17 00:00:00 2001 From: JackDoan Date: Wed, 21 Jan 2026 13:22:59 -0600 Subject: [PATCH] SNAT the ICMPv4 identifier --- firewall.go | 12 ++++++++---- firewall/packet.go | 4 ++-- outside.go | 17 ++++++----------- snat.go | 30 ++++++++++++++++++++++++++++++ 4 files changed, 46 insertions(+), 17 deletions(-) diff --git a/firewall.go b/firewall.go index 3ddd85bf..1ff08d86 100644 --- a/firewall.go +++ b/firewall.go @@ -449,6 +449,10 @@ func (f *Firewall) unSnat(data []byte, fp *firewall.Packet, c *conn) netip.Addr dstport := ipHeaderLen + 2 switch fp.Protocol { + case firewall.ProtoICMP: + binary.BigEndian.PutUint16(data[ipHeaderLen+4:ipHeaderLen+6], c.snat.Src.Port()) + icmpCode := uint16(data[ipHeaderLen+1]) //todo not snatting on this yet (but Linux would) + recalcICMPv4Checksum(data, icmpCode, icmpCode, c.snat.SnatPort, c.snat.Src.Port()) case firewall.ProtoUDP: binary.BigEndian.PutUint16(data[dstport:dstport+2], c.snat.Src.Port()) recalcUDPv4Checksum(data, oldIP, c.snat.Src) @@ -460,7 +464,6 @@ func (f *Firewall) unSnat(data []byte, fp *firewall.Packet, c *conn) netip.Addr } func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo *HostInfo) { - //todo set srcport if c.snat.Valid() { //old flow fp.RemoteAddr = f.snatAddr @@ -471,7 +474,6 @@ func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo c.snat.SrcVpnIp = hostinfo.vpnAddrs[0] fp.RemoteAddr = f.snatAddr - //find a new port to use, if needed for { existingFlow := f.peek(*fp) //locking and unlocking for each peek is slow, but simple for now @@ -494,7 +496,7 @@ func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo return } - newIP := netip.AddrPortFrom(f.snatAddr, fp.RemotePort) + newIP := netip.AddrPortFrom(f.snatAddr, c.snat.SnatPort) //change src IP copy(data[12:], f.snatAddr.AsSlice()) recalcIPv4Checksum(data, c.snat.Src.Addr(), newIP.Addr()) @@ -502,7 +504,9 @@ func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo switch fp.Protocol { case firewall.ProtoICMP: - //todo! + binary.BigEndian.PutUint16(data[ipHeaderLen+4:ipHeaderLen+6], c.snat.SnatPort) + icmpCode := uint16(data[ipHeaderLen+1]) //todo not snatting on this yet (but Linux would) + recalcICMPv4Checksum(data, icmpCode, icmpCode, c.snat.Src.Port(), c.snat.SnatPort) case firewall.ProtoUDP: //src port is at offset 0 binary.BigEndian.PutUint16(data[ipHeaderLen:ipHeaderLen+2], c.snat.SnatPort) diff --git a/firewall/packet.go b/firewall/packet.go index 8482d476..ce00129c 100644 --- a/firewall/packet.go +++ b/firewall/packet.go @@ -23,10 +23,10 @@ type Packet struct { LocalAddr netip.Addr RemoteAddr netip.Addr // LocalPort is the destination port for incoming traffic, or the source port for outgoing. - // For ICMP, it's the "identifier". This is only used for connection tracking, actual firewall rules will not filter on ICMP identifier + // For ICMP, it's "code". //todo also store "type?" would need to decode replies, which sucks LocalPort uint16 // RemotePort is the source port for incoming traffic, or the destination port for outgoing. - // For ICMP, it's "code". //todo also store "type?" would need to decode replies, which sucks + // For ICMP, it's the "identifier". This is only used for connection tracking, actual firewall rules will not filter on ICMP identifier RemotePort uint16 Protocol uint8 Fragment bool diff --git a/outside.go b/outside.go index 572f1cc8..0c82f389 100644 --- a/outside.go +++ b/outside.go @@ -437,14 +437,9 @@ func parseV4(data []byte, incoming bool, fp *firewall.Packet) error { if fp.Fragment { fp.RemotePort = 0 fp.LocalPort = 0 - } else if fp.Protocol == firewall.ProtoICMP { - //todo remove comment - //icmpType := data[ihl] - //icmpCode := data[ihl+1] - //icmpChecksum := data[ihl+2 : ihl+4] - //icmpIdentifier := data[ihl+4 : ihl+6] - fp.RemotePort = uint16(data[ihl+1]) //code - fp.LocalPort = binary.BigEndian.Uint16(data[ihl+4 : ihl+6]) //identifier + } else if fp.Protocol == firewall.ProtoICMP { //note that orientation doesn't matter on ICMP + fp.RemotePort = binary.BigEndian.Uint16(data[ihl+4 : ihl+6]) //identifier + fp.LocalPort = uint16(data[ihl+1]) //code } else { fp.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2]) //src port fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) //dst port @@ -455,9 +450,9 @@ func parseV4(data []byte, incoming bool, fp *firewall.Packet) error { if fp.Fragment { fp.RemotePort = 0 fp.LocalPort = 0 - } else if fp.Protocol == firewall.ProtoICMP { - fp.RemotePort = uint16(data[ihl+1]) //code - fp.LocalPort = binary.BigEndian.Uint16(data[ihl+4 : ihl+6]) //identifier + } else if fp.Protocol == firewall.ProtoICMP { //note that orientation doesn't matter on ICMP + fp.RemotePort = binary.BigEndian.Uint16(data[ihl+4 : ihl+6]) //identifier + fp.LocalPort = uint16(data[ihl+1]) //code } else { fp.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2]) //src port fp.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) //dst port diff --git a/snat.go b/snat.go index e9c7a804..3164b641 100644 --- a/snat.go +++ b/snat.go @@ -59,3 +59,33 @@ func recalcTCPv4Checksum(data []byte, oldSrcIP netip.AddrPort, newSrcIP netip.Ad const offsetInsideHeader = 16 recalcV4TransportChecksum(offsetInsideHeader, data, oldSrcIP, newSrcIP) } + +func calcNewICMPChecksum(oldChecksum uint16, oldCode uint16, newCode uint16, oldID uint16, newID uint16) uint16 { + // Start with inverted checksum + checksum := uint32(^oldChecksum) + + // Subtract old stuff + checksum += uint32(^oldCode) + checksum += uint32(^oldID) + + // Add new stuff + checksum += uint32(newCode) + checksum += uint32(newID) + + // Fold carries + for checksum > 0xFFFF { + checksum = (checksum & 0xFFFF) + (checksum >> 16) + } + + // Return ones' complement + return ^uint16(checksum) +} + +func recalcICMPv4Checksum(data []byte, oldCode uint16, newCode uint16, oldID uint16, newID uint16) { + const offsetInsideHeader = 2 + ipHeaderOffset := int(data[0]&0x0F) * 4 + offset := ipHeaderOffset + offsetInsideHeader + oldChecksum := binary.BigEndian.Uint16(data[offset : offset+2]) + checksum := calcNewICMPChecksum(oldChecksum, oldCode, newCode, oldID, newID) + binary.BigEndian.PutUint16(data[offset:offset+2], checksum) +}