mirror of
https://github.com/slackhq/nebula.git
synced 2026-01-23 13:01:34 -08:00
Merge cf2b5455bf into 0b02d982b2
This commit is contained in:
commit
e25d2aeb73
21 changed files with 631 additions and 106 deletions
|
|
@ -441,7 +441,7 @@ func (c *certificateV2) validate() error {
|
|||
}
|
||||
} else if network.Addr().Is4() {
|
||||
if !hasV4Networks {
|
||||
return NewErrInvalidCertificateProperties("IPv4 unsafe networks require an IPv4 address assignment: %s", network)
|
||||
//return NewErrInvalidCertificateProperties("IPv4 unsafe networks require an IPv4 address assignment: %s", network)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
248
firewall.go
248
firewall.go
|
|
@ -2,6 +2,7 @@ package nebula
|
|||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
|
@ -26,6 +27,19 @@ type FirewallInterface interface {
|
|||
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr string, caName string, caSha string) error
|
||||
}
|
||||
|
||||
type snatInfo struct {
|
||||
//Src is the source IP+port to write into unsafe-route-bound packet
|
||||
Src netip.AddrPort
|
||||
//SrcVpnIp is the overlay IP associated with this flow. It's needed to associate reply traffic so we can get it back to the right host.
|
||||
SrcVpnIp netip.Addr
|
||||
//SnatPort is the port to rewrite into an overlay-bound packet
|
||||
SnatPort uint16
|
||||
}
|
||||
|
||||
func (s *snatInfo) Valid() bool {
|
||||
return s.Src.IsValid()
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
Expires time.Time // Time when this conntrack entry will expire
|
||||
|
||||
|
|
@ -34,6 +48,9 @@ type conn struct {
|
|||
// fields pack for free after the uint32 above
|
||||
incoming bool
|
||||
rulesVersion uint16
|
||||
|
||||
//for SNAT support
|
||||
snat snatInfo
|
||||
}
|
||||
|
||||
// TODO: need conntrack max tracked connections handling
|
||||
|
|
@ -66,6 +83,7 @@ type Firewall struct {
|
|||
defaultLocalCIDRAny bool
|
||||
incomingMetrics firewallMetrics
|
||||
outgoingMetrics firewallMetrics
|
||||
snatAddr netip.Addr
|
||||
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
|
@ -149,12 +167,14 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
|
|||
tmax = defaultTimeout
|
||||
}
|
||||
|
||||
hasV4Networks := false
|
||||
routableNetworks := new(bart.Lite)
|
||||
var assignedNetworks []netip.Prefix
|
||||
for _, network := range c.Networks() {
|
||||
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
|
||||
routableNetworks.Insert(nprefix)
|
||||
assignedNetworks = append(assignedNetworks, network)
|
||||
hasV4Networks = hasV4Networks || network.Addr().Is4()
|
||||
}
|
||||
|
||||
hasUnsafeNetworks := false
|
||||
|
|
@ -163,6 +183,11 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
|
|||
hasUnsafeNetworks = true
|
||||
}
|
||||
|
||||
snatAddr := netip.Addr{}
|
||||
if hasUnsafeNetworks && !hasV4Networks {
|
||||
snatAddr = netip.MustParseAddr("169.254.55.96") //todo this needs to come from the config, or perhaps the tun
|
||||
}
|
||||
|
||||
return &Firewall{
|
||||
Conntrack: &FirewallConntrack{
|
||||
Conns: make(map[firewall.Packet]*conn),
|
||||
|
|
@ -176,6 +201,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
|
|||
routableNetworks: routableNetworks,
|
||||
assignedNetworks: assignedNetworks,
|
||||
hasUnsafeNetworks: hasUnsafeNetworks,
|
||||
snatAddr: snatAddr,
|
||||
l: l,
|
||||
|
||||
incomingMetrics: firewallMetrics{
|
||||
|
|
@ -401,22 +427,131 @@ var ErrInvalidRemoteIP = errors.New("remote address is not in remote certificate
|
|||
var ErrInvalidLocalIP = errors.New("local address is not in list of handled local addresses")
|
||||
var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
|
||||
|
||||
func (f *Firewall) unSnat(data []byte, fp *firewall.Packet, c *conn) netip.Addr {
|
||||
if c == nil {
|
||||
//unfortunately this needs to lock. Surely there's a better way, but I need to make this flow at all first.
|
||||
c = f.peek(*fp)
|
||||
}
|
||||
if c == nil {
|
||||
return netip.Addr{}
|
||||
}
|
||||
if !c.snat.Valid() {
|
||||
return netip.Addr{}
|
||||
}
|
||||
|
||||
oldIP := netip.AddrPortFrom(f.snatAddr, fp.RemotePort)
|
||||
|
||||
//change dst IP
|
||||
copy(data[16:], c.snat.Src.Addr().AsSlice())
|
||||
recalcIPv4Checksum(data, oldIP.Addr(), c.snat.Src.Addr())
|
||||
ipHeaderLen := int(data[0]&0x0F) * 4
|
||||
//dst port is at offset 2
|
||||
dstport := ipHeaderLen + 2
|
||||
|
||||
switch fp.Protocol {
|
||||
case firewall.ProtoICMP:
|
||||
binary.BigEndian.PutUint16(data[ipHeaderLen+4:ipHeaderLen+6], c.snat.Src.Port())
|
||||
icmpCode := uint16(data[ipHeaderLen+1]) //todo not snatting on this yet (but Linux would)
|
||||
recalcICMPv4Checksum(data, icmpCode, icmpCode, c.snat.SnatPort, c.snat.Src.Port())
|
||||
case firewall.ProtoUDP:
|
||||
binary.BigEndian.PutUint16(data[dstport:dstport+2], c.snat.Src.Port())
|
||||
recalcUDPv4Checksum(data, oldIP, c.snat.Src)
|
||||
case firewall.ProtoTCP:
|
||||
binary.BigEndian.PutUint16(data[dstport:dstport+2], c.snat.Src.Port())
|
||||
recalcTCPv4Checksum(data, oldIP, c.snat.Src)
|
||||
}
|
||||
return c.snat.SrcVpnIp
|
||||
}
|
||||
|
||||
func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo *HostInfo) {
|
||||
if c.snat.Valid() {
|
||||
//old flow
|
||||
fp.RemoteAddr = f.snatAddr
|
||||
fp.RemotePort = c.snat.SnatPort
|
||||
} else if hostinfo.vpnAddrs[0].Is6() {
|
||||
//we got a new flow
|
||||
c.snat.Src = netip.AddrPortFrom(fp.RemoteAddr, fp.RemotePort)
|
||||
c.snat.SrcVpnIp = hostinfo.vpnAddrs[0]
|
||||
|
||||
fp.RemoteAddr = f.snatAddr
|
||||
//find a new port to use, if needed
|
||||
for {
|
||||
existingFlow := f.peek(*fp) //locking and unlocking for each peek is slow, but simple for now
|
||||
if existingFlow == nil {
|
||||
break //yay, we can use this port
|
||||
}
|
||||
//increment and retry. There's probably better strategies out there
|
||||
fp.RemotePort++
|
||||
if fp.RemotePort < 0x7ff {
|
||||
fp.RemotePort += 0x7ff // keep it ephemeral for now
|
||||
} //todo if we're totally out of ports this loops forever. Probably not good.
|
||||
}
|
||||
c.snat.SnatPort = fp.RemotePort
|
||||
} else {
|
||||
f.l.WithFields(logrus.Fields{
|
||||
"fp": *fp,
|
||||
"conn": *c,
|
||||
"hostinfo": hostinfo,
|
||||
}).Error("this packet cannot be snatted")
|
||||
return
|
||||
}
|
||||
|
||||
newIP := netip.AddrPortFrom(f.snatAddr, c.snat.SnatPort)
|
||||
//change src IP
|
||||
copy(data[12:], f.snatAddr.AsSlice())
|
||||
recalcIPv4Checksum(data, c.snat.Src.Addr(), newIP.Addr())
|
||||
ipHeaderLen := int(data[0]&0x0F) * 4
|
||||
|
||||
switch fp.Protocol {
|
||||
case firewall.ProtoICMP:
|
||||
binary.BigEndian.PutUint16(data[ipHeaderLen+4:ipHeaderLen+6], c.snat.SnatPort)
|
||||
icmpCode := uint16(data[ipHeaderLen+1]) //todo not snatting on this yet (but Linux would)
|
||||
recalcICMPv4Checksum(data, icmpCode, icmpCode, c.snat.Src.Port(), c.snat.SnatPort)
|
||||
case firewall.ProtoUDP:
|
||||
//src port is at offset 0
|
||||
binary.BigEndian.PutUint16(data[ipHeaderLen:ipHeaderLen+2], c.snat.SnatPort)
|
||||
recalcUDPv4Checksum(data, c.snat.Src, newIP)
|
||||
case firewall.ProtoTCP:
|
||||
//src port is at offset 0
|
||||
binary.BigEndian.PutUint16(data[ipHeaderLen:ipHeaderLen+2], c.snat.SnatPort)
|
||||
recalcTCPv4Checksum(data, c.snat.Src, newIP)
|
||||
}
|
||||
}
|
||||
|
||||
// Drop returns an error if the packet should be dropped, explaining why. It
|
||||
// returns nil if the packet should not be dropped.
|
||||
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error {
|
||||
func (f *Firewall) Drop(fp firewall.Packet, pkt []byte, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error {
|
||||
specialSnatMode := f.hasUnsafeNetworks && fp.IsIPv4() && h.HasOnlyV6Addresses() //todo I wish I only set this once somehow
|
||||
|
||||
table := f.OutRules
|
||||
if incoming {
|
||||
table = f.InRules
|
||||
}
|
||||
|
||||
// Check if we spoke to this tuple, if we did then allow this packet
|
||||
if f.inConns(fp, h, caPool, localCache) {
|
||||
return nil
|
||||
// Check the cache first, iff not snatting
|
||||
if localCache != nil && !specialSnatMode {
|
||||
if _, ok := localCache[fp]; ok {
|
||||
return nil //packet matched the cache, we're not snatting, we can return early!
|
||||
}
|
||||
}
|
||||
c := f.inConns(fp, h, caPool, localCache)
|
||||
if c != nil {
|
||||
//can't return yet, need to snat maybe
|
||||
goto snat
|
||||
}
|
||||
|
||||
// Make sure remote address matches nebula certificate, and determine how to treat it
|
||||
if h.networks == nil {
|
||||
// Simple case: Certificate has one address and no unsafe networks
|
||||
if h.vpnAddrs[0] != fp.RemoteAddr {
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
return ErrInvalidRemoteIP
|
||||
}
|
||||
if !specialSnatMode {
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
return ErrInvalidRemoteIP //todo!
|
||||
} //else we're in special snat mode, and we need to apply more checks below
|
||||
} //else? all good, fall through
|
||||
} else {
|
||||
//todo check for srcsnortaddr here too?
|
||||
nwType, ok := h.networks.Lookup(fp.RemoteAddr)
|
||||
if !ok {
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
|
|
@ -426,9 +561,12 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
|
|||
case NetworkTypeVPN:
|
||||
break // nothing special
|
||||
case NetworkTypeVPNPeer:
|
||||
//todo we might need a specialSnatMode case in here to handle routers with v4 addresses when we don't also have a v4 address?
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
return ErrPeerRejected // reject for now, one day this may have different FW rules
|
||||
case NetworkTypeUnsafe:
|
||||
//intentionally excluding f.hasUnsafeNetworks -- this is what lets routers talk back to us with our unsafe traffic!
|
||||
specialSnatMode = fp.IsIPv4() && h.HasOnlyV6Addresses() && f.assignedNetworks[0].Addr().Is6()
|
||||
break // nothing special, one day this may have different FW rules
|
||||
default:
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
|
|
@ -437,16 +575,12 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
|
|||
}
|
||||
|
||||
// Make sure we are supposed to be handling this local ip address
|
||||
if !f.routableNetworks.Contains(fp.LocalAddr) {
|
||||
//todo I'm not sure I trust this heuristic
|
||||
if !specialSnatMode && !f.routableNetworks.Contains(fp.LocalAddr) {
|
||||
f.metrics(incoming).droppedLocalAddr.Inc(1)
|
||||
return ErrInvalidLocalIP
|
||||
}
|
||||
|
||||
table := f.OutRules
|
||||
if incoming {
|
||||
table = f.InRules
|
||||
}
|
||||
|
||||
// We now know which firewall table to check against
|
||||
if !table.match(fp, incoming, h.ConnectionState.peerCert, caPool) {
|
||||
f.metrics(incoming).droppedNoRule.Inc(1)
|
||||
|
|
@ -454,7 +588,16 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
|
|||
}
|
||||
|
||||
// We always want to conntrack since it is a faster operation
|
||||
f.addConn(fp, incoming)
|
||||
c = f.addConn(fp, incoming)
|
||||
|
||||
snat:
|
||||
if incoming {
|
||||
if specialSnatMode {
|
||||
//todo do not snat if you are not a router for the destination -- for now, just if you're not a router
|
||||
f.applySnat(pkt, &fp, c, h)
|
||||
f.dupeConn(fp, c) //track the snatted flow with the same expiration as the unsnatted version
|
||||
}
|
||||
} //outgoing snat is handled before this function is called (for now!)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -483,12 +626,23 @@ func (f *Firewall) EmitStats() {
|
|||
metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
|
||||
}
|
||||
|
||||
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool {
|
||||
if localCache != nil {
|
||||
if _, ok := localCache[fp]; ok {
|
||||
return true
|
||||
}
|
||||
func (f *Firewall) peek(fp firewall.Packet) *conn {
|
||||
conntrack := f.Conntrack
|
||||
conntrack.Lock()
|
||||
|
||||
// Purge every time we test
|
||||
ep, has := conntrack.TimerWheel.Purge()
|
||||
if has {
|
||||
f.evict(ep)
|
||||
}
|
||||
|
||||
c := conntrack.Conns[fp]
|
||||
|
||||
conntrack.Unlock()
|
||||
return c
|
||||
}
|
||||
|
||||
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) *conn {
|
||||
conntrack := f.Conntrack
|
||||
conntrack.Lock()
|
||||
|
||||
|
|
@ -500,9 +654,22 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
|
|||
|
||||
c, ok := conntrack.Conns[fp]
|
||||
|
||||
if !ok && fp.Protocol == firewall.ProtoICMP {
|
||||
//todo this seems like it will also bite me
|
||||
oldRemote := fp.RemotePort
|
||||
oldLocal := fp.LocalPort
|
||||
fp.RemotePort = 0
|
||||
fp.LocalPort = 0
|
||||
c, ok = conntrack.Conns[fp]
|
||||
if ok {
|
||||
fp.RemotePort = oldRemote
|
||||
fp.LocalPort = oldLocal
|
||||
}
|
||||
}
|
||||
|
||||
if !ok {
|
||||
conntrack.Unlock()
|
||||
return false
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.rulesVersion != f.rulesVersion {
|
||||
|
|
@ -525,7 +692,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
|
|||
}
|
||||
delete(conntrack.Conns, fp)
|
||||
conntrack.Unlock()
|
||||
return false
|
||||
return nil
|
||||
}
|
||||
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
|
|
@ -555,12 +722,11 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
|
|||
localCache[fp] = struct{}{}
|
||||
}
|
||||
|
||||
return true
|
||||
return c
|
||||
}
|
||||
|
||||
func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
|
||||
func (f *Firewall) packetTimeout(fp firewall.Packet) time.Duration {
|
||||
var timeout time.Duration
|
||||
c := &conn{}
|
||||
|
||||
switch fp.Protocol {
|
||||
case firewall.ProtoTCP:
|
||||
|
|
@ -570,7 +736,25 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
|
|||
default:
|
||||
timeout = f.DefaultTimeout
|
||||
}
|
||||
return timeout
|
||||
}
|
||||
|
||||
func (f *Firewall) dupeConn(fp firewall.Packet, c *conn) {
|
||||
conntrack := f.Conntrack
|
||||
conntrack.Lock()
|
||||
if _, ok := conntrack.Conns[fp]; !ok {
|
||||
conntrack.TimerWheel.Advance(time.Now())
|
||||
conntrack.TimerWheel.Add(fp, f.packetTimeout(fp))
|
||||
}
|
||||
|
||||
conntrack.Conns[fp] = c
|
||||
conntrack.Unlock()
|
||||
}
|
||||
|
||||
func (f *Firewall) addConn(fp firewall.Packet, incoming bool) *conn {
|
||||
c := &conn{}
|
||||
|
||||
timeout := f.packetTimeout(fp)
|
||||
conntrack := f.Conntrack
|
||||
conntrack.Lock()
|
||||
if _, ok := conntrack.Conns[fp]; !ok {
|
||||
|
|
@ -584,7 +768,18 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
|
|||
c.rulesVersion = f.rulesVersion
|
||||
c.Expires = time.Now().Add(timeout)
|
||||
conntrack.Conns[fp] = c
|
||||
|
||||
//todo will this bite me somehow?
|
||||
if fp.Protocol == firewall.ProtoICMP {
|
||||
//not required for ICMPv6 because we don't decode or SNAT it
|
||||
//create a duplicate conntrack entry with all the port information zeroed?
|
||||
fp.RemotePort = 0
|
||||
fp.LocalPort = 0
|
||||
conntrack.Conns[fp] = c
|
||||
}
|
||||
|
||||
conntrack.Unlock()
|
||||
return c
|
||||
}
|
||||
|
||||
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
|
||||
|
|
@ -662,6 +857,13 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCer
|
|||
|
||||
var port int32
|
||||
|
||||
if p.Protocol == firewall.ProtoICMP {
|
||||
// port numbers are re-used for connection tracking and SNAT,
|
||||
// but we don't want to actually filter on them for ICMP
|
||||
// ICMP6 is omitted because we don't attempt to parse code/identifier/etc out of ICMP6
|
||||
return fp[firewall.PortAny].match(p, c, caPool)
|
||||
}
|
||||
|
||||
if p.Fragment {
|
||||
port = firewall.PortFragment
|
||||
} else if incoming {
|
||||
|
|
|
|||
|
|
@ -22,12 +22,20 @@ const (
|
|||
type Packet struct {
|
||||
LocalAddr netip.Addr
|
||||
RemoteAddr netip.Addr
|
||||
LocalPort uint16
|
||||
// LocalPort is the destination port for incoming traffic, or the source port for outgoing.
|
||||
// For ICMP, it's "code". //todo also store "type?" would need to decode replies, which sucks
|
||||
LocalPort uint16
|
||||
// RemotePort is the source port for incoming traffic, or the destination port for outgoing.
|
||||
// For ICMP, it's the "identifier". This is only used for connection tracking, actual firewall rules will not filter on ICMP identifier
|
||||
RemotePort uint16
|
||||
Protocol uint8
|
||||
Fragment bool
|
||||
}
|
||||
|
||||
func (fp *Packet) IsIPv4() bool {
|
||||
return fp.LocalAddr.Is4() && fp.RemoteAddr.Is4()
|
||||
}
|
||||
|
||||
func (fp *Packet) Copy() *Packet {
|
||||
return &Packet{
|
||||
LocalAddr: fp.LocalAddr,
|
||||
|
|
|
|||
213
firewall_test.go
213
firewall_test.go
|
|
@ -212,44 +212,44 @@ func TestFirewall_Drop(t *testing.T) {
|
|||
cp := cert.NewCAPool()
|
||||
|
||||
// Drop outbound
|
||||
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
|
||||
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, nil, false, &h, cp, nil))
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
|
||||
// Allow outbound because conntrack
|
||||
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil))
|
||||
|
||||
// test remote mismatch
|
||||
oldRemote := p.RemoteAddr
|
||||
p.RemoteAddr = netip.MustParseAddr("1.2.3.10")
|
||||
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
|
||||
assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrInvalidRemoteIP)
|
||||
p.RemoteAddr = oldRemote
|
||||
|
||||
// ensure signer doesn't get in the way of group checks
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
|
||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
|
||||
// test caSha doesn't drop on match
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
|
||||
|
||||
// ensure ca name doesn't get in the way of group checks
|
||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
|
||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
|
||||
// test caName doesn't drop on match
|
||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
|
||||
}
|
||||
|
||||
func TestFirewall_DropV6(t *testing.T) {
|
||||
|
|
@ -291,44 +291,44 @@ func TestFirewall_DropV6(t *testing.T) {
|
|||
cp := cert.NewCAPool()
|
||||
|
||||
// Drop outbound
|
||||
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
|
||||
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, nil, false, &h, cp, nil))
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
|
||||
// Allow outbound because conntrack
|
||||
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil))
|
||||
|
||||
// test remote mismatch
|
||||
oldRemote := p.RemoteAddr
|
||||
p.RemoteAddr = netip.MustParseAddr("fd12::56")
|
||||
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
|
||||
assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrInvalidRemoteIP)
|
||||
p.RemoteAddr = oldRemote
|
||||
|
||||
// ensure signer doesn't get in the way of group checks
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
|
||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
|
||||
// test caSha doesn't drop on match
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
|
||||
|
||||
// ensure ca name doesn't get in the way of group checks
|
||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
|
||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
|
||||
// test caName doesn't drop on match
|
||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
|
||||
}
|
||||
|
||||
func BenchmarkFirewallTable_match(b *testing.B) {
|
||||
|
|
@ -536,10 +536,10 @@ func TestFirewall_Drop2(t *testing.T) {
|
|||
cp := cert.NewCAPool()
|
||||
|
||||
// h1/c1 lacks the proper groups
|
||||
require.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule)
|
||||
require.ErrorIs(t, fw.Drop(p, nil, true, &h1, cp, nil), ErrNoMatchingRule)
|
||||
// c has the proper groups
|
||||
resetConntrack(fw)
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
|
||||
}
|
||||
|
||||
func TestFirewall_Drop3(t *testing.T) {
|
||||
|
|
@ -617,18 +617,18 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||
cp := cert.NewCAPool()
|
||||
|
||||
// c1 should pass because host match
|
||||
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h1, cp, nil))
|
||||
// c2 should pass because ca sha match
|
||||
resetConntrack(fw)
|
||||
require.NoError(t, fw.Drop(p, true, &h2, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h2, cp, nil))
|
||||
// c3 should fail because no match
|
||||
resetConntrack(fw)
|
||||
assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(p, nil, true, &h3, cp, nil), ErrNoMatchingRule)
|
||||
|
||||
// Test a remote address match
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", ""))
|
||||
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h1, cp, nil))
|
||||
}
|
||||
|
||||
func TestFirewall_Drop3V6(t *testing.T) {
|
||||
|
|
@ -666,7 +666,7 @@ func TestFirewall_Drop3V6(t *testing.T) {
|
|||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
cp := cert.NewCAPool()
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", ""))
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
|
||||
}
|
||||
|
||||
func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||
|
|
@ -708,12 +708,12 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||
cp := cert.NewCAPool()
|
||||
|
||||
// Drop outbound
|
||||
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
|
||||
// Allow outbound because conntrack
|
||||
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil))
|
||||
|
||||
oldFw := fw
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
|
|
@ -722,7 +722,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||
|
||||
// Allow outbound because conntrack and new rules allow port 10
|
||||
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil))
|
||||
|
||||
oldFw = fw
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
|
|
@ -731,7 +731,160 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||
|
||||
// Drop outbound because conntrack doesn't match new ruleset
|
||||
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
}
|
||||
|
||||
func TestFirewall_ICMPPortBehavior(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
myVpnNetworksTable := new(bart.Lite)
|
||||
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
|
||||
|
||||
network := netip.MustParsePrefix("1.2.3.4/24")
|
||||
|
||||
c := cert.CachedCertificate{
|
||||
Certificate: &dummyCert{
|
||||
name: "host1",
|
||||
networks: []netip.Prefix{network},
|
||||
groups: []string{"default-group"},
|
||||
issuer: "signer-shasum",
|
||||
},
|
||||
InvertedGroups: map[string]struct{}{"default-group": {}},
|
||||
}
|
||||
h := HostInfo{
|
||||
ConnectionState: &ConnectionState{
|
||||
peerCert: &c,
|
||||
},
|
||||
vpnAddrs: []netip.Addr{network.Addr()},
|
||||
}
|
||||
h.buildNetworks(myVpnNetworksTable, c.Certificate)
|
||||
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
templ := firewall.Packet{
|
||||
LocalAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
Protocol: firewall.ProtoICMP,
|
||||
Fragment: false,
|
||||
}
|
||||
|
||||
t.Run("ICMP allowed", func(t *testing.T) {
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 0, 0, []string{"any"}, "", "", "", "", ""))
|
||||
t.Run("zero ports", func(t *testing.T) {
|
||||
p := templ.Copy()
|
||||
p.LocalPort = 0
|
||||
p.RemotePort = 0
|
||||
// Drop outbound
|
||||
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil))
|
||||
//now also allow outbound
|
||||
require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil))
|
||||
})
|
||||
|
||||
t.Run("nonzero ports", func(t *testing.T) {
|
||||
p := templ.Copy()
|
||||
p.LocalPort = 0xabcd
|
||||
p.RemotePort = 0x1234
|
||||
// Drop outbound
|
||||
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil))
|
||||
//now also allow outbound
|
||||
require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil))
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Any proto, some ports allowed", func(t *testing.T) {
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 80, 444, []string{"any"}, "", "", "", "", ""))
|
||||
t.Run("zero ports, still blocked", func(t *testing.T) {
|
||||
p := templ.Copy()
|
||||
p.LocalPort = 0
|
||||
p.RemotePort = 0
|
||||
// Drop outbound
|
||||
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
//now also allow outbound
|
||||
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
})
|
||||
|
||||
t.Run("nonzero ports, still blocked", func(t *testing.T) {
|
||||
p := templ.Copy()
|
||||
p.LocalPort = 0xabcd
|
||||
p.RemotePort = 0x1234
|
||||
// Drop outbound
|
||||
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
//now also allow outbound
|
||||
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
})
|
||||
|
||||
t.Run("nonzero, matching ports, still blocked", func(t *testing.T) {
|
||||
p := templ.Copy()
|
||||
p.LocalPort = 80
|
||||
p.RemotePort = 80
|
||||
// Drop outbound
|
||||
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
//now also allow outbound
|
||||
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
})
|
||||
})
|
||||
t.Run("Any proto, any port", func(t *testing.T) {
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
|
||||
t.Run("zero ports, allowed", func(t *testing.T) {
|
||||
p := templ.Copy()
|
||||
p.LocalPort = 0
|
||||
p.RemotePort = 0
|
||||
// Drop outbound
|
||||
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil))
|
||||
//now also allow outbound
|
||||
require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil))
|
||||
})
|
||||
|
||||
t.Run("nonzero ports, allowed", func(t *testing.T) {
|
||||
p := templ.Copy()
|
||||
p.LocalPort = 0xabcd
|
||||
p.RemotePort = 0x1234
|
||||
// Drop outbound
|
||||
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil))
|
||||
//now also allow outbound
|
||||
require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil))
|
||||
})
|
||||
|
||||
t.Run("nonzero ports, allowed", func(t *testing.T) {
|
||||
p := templ.Copy()
|
||||
p.LocalPort = 0xabcd
|
||||
p.RemotePort = 0x1234
|
||||
// Drop outbound
|
||||
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil))
|
||||
//now also allow outbound with a different ID
|
||||
p.RemotePort++
|
||||
require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil))
|
||||
})
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func TestFirewall_DropIPSpoofing(t *testing.T) {
|
||||
|
|
@ -777,7 +930,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
|
|||
Protocol: firewall.ProtoUDP,
|
||||
Fragment: false,
|
||||
}
|
||||
assert.Equal(t, fw.Drop(p, true, &h1, cp, nil), ErrInvalidRemoteIP)
|
||||
assert.Equal(t, fw.Drop(p, nil, true, &h1, cp, nil), ErrInvalidRemoteIP)
|
||||
}
|
||||
|
||||
func BenchmarkLookup(b *testing.B) {
|
||||
|
|
@ -1184,7 +1337,7 @@ func (c *testcase) Test(t *testing.T, fw *Firewall) {
|
|||
t.Helper()
|
||||
cp := cert.NewCAPool()
|
||||
resetConntrack(fw)
|
||||
err := fw.Drop(c.p, true, c.h, cp, nil)
|
||||
err := fw.Drop(c.p, nil, true, c.h, cp, nil)
|
||||
if c.err == nil {
|
||||
require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr)
|
||||
} else {
|
||||
|
|
|
|||
10
hostmap.go
10
hostmap.go
|
|
@ -224,6 +224,7 @@ const (
|
|||
NetworkTypeVPNPeer
|
||||
// NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks()
|
||||
NetworkTypeUnsafe
|
||||
//todo consider NetworkTypeLinkLocal or NetworkTypeSNAT
|
||||
)
|
||||
|
||||
type HostInfo struct {
|
||||
|
|
@ -277,6 +278,15 @@ type HostInfo struct {
|
|||
lastUsed time.Time
|
||||
}
|
||||
|
||||
func (i *HostInfo) HasOnlyV6Addresses() bool {
|
||||
for _, vpnIp := range i.vpnAddrs {
|
||||
if !vpnIp.Is6() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
type ViaSender struct {
|
||||
UdpAddr netip.AddrPort
|
||||
relayHI *HostInfo // relayHI is the host info object of the relay
|
||||
|
|
|
|||
26
inside.go
26
inside.go
|
|
@ -48,9 +48,24 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
|||
return
|
||||
}
|
||||
|
||||
hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
|
||||
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
|
||||
})
|
||||
var hostinfo *HostInfo
|
||||
var ready bool
|
||||
|
||||
snatMode := f.firewall.snatAddr.IsValid() && fwPacket.RemoteAddr == f.firewall.snatAddr
|
||||
|
||||
if snatMode {
|
||||
//todo unsnat happens here, would be nice to not
|
||||
destVpnAddr := f.firewall.unSnat(packet, fwPacket, nil) //todo bail if we can't unsnat?
|
||||
if destVpnAddr.IsValid() {
|
||||
hostinfo, ready = f.getOrHandshakeNoRouting(destVpnAddr, func(hh *HandshakeHostInfo) {
|
||||
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
|
||||
})
|
||||
} //otherwise, hostinfo will be nil
|
||||
} else { //if we didn't need to unsnat
|
||||
hostinfo, ready = f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
|
||||
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
|
||||
})
|
||||
}
|
||||
|
||||
if hostinfo == nil {
|
||||
f.rejectInside(packet, out, q)
|
||||
|
|
@ -66,10 +81,9 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
|||
return
|
||||
}
|
||||
|
||||
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
|
||||
dropReason := f.firewall.Drop(*fwPacket, packet, false, hostinfo, f.pki.GetCAPool(), localCache)
|
||||
if dropReason == nil {
|
||||
f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
|
||||
|
||||
} else {
|
||||
f.rejectInside(packet, out, q)
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
|
|
@ -218,7 +232,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
|
|||
}
|
||||
|
||||
// check if packet is in outbound fw rules
|
||||
dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil)
|
||||
dropReason := f.firewall.Drop(*fp, nil, false, hostinfo, f.pki.GetCAPool(), nil)
|
||||
if dropReason != nil {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("fwPacket", fp).
|
||||
|
|
|
|||
3
main.go
3
main.go
|
|
@ -135,7 +135,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||
deviceFactory = overlay.NewDeviceFromConfig
|
||||
}
|
||||
|
||||
tun, err = deviceFactory(c, l, pki.getCertState().myVpnNetworks, routines)
|
||||
cs := pki.getCertState()
|
||||
tun, err = deviceFactory(c, l, cs.myVpnNetworks, cs.GetDefaultCertificate().UnsafeNetworks(), routines)
|
||||
if err != nil {
|
||||
return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err)
|
||||
}
|
||||
|
|
|
|||
22
outside.go
22
outside.go
|
|
@ -329,7 +329,7 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
|
|||
switch proto {
|
||||
case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader:
|
||||
fp.Protocol = uint8(proto)
|
||||
fp.RemotePort = 0
|
||||
fp.RemotePort = 0 //we don't attempt to parse ICMPv6 because we don't SNAT it
|
||||
fp.LocalPort = 0
|
||||
fp.Fragment = false
|
||||
return nil
|
||||
|
|
@ -434,22 +434,28 @@ func parseV4(data []byte, incoming bool, fp *firewall.Packet) error {
|
|||
if incoming {
|
||||
fp.RemoteAddr, _ = netip.AddrFromSlice(data[12:16])
|
||||
fp.LocalAddr, _ = netip.AddrFromSlice(data[16:20])
|
||||
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
|
||||
if fp.Fragment {
|
||||
fp.RemotePort = 0
|
||||
fp.LocalPort = 0
|
||||
} else if fp.Protocol == firewall.ProtoICMP { //note that orientation doesn't matter on ICMP
|
||||
fp.RemotePort = binary.BigEndian.Uint16(data[ihl+4 : ihl+6]) //identifier
|
||||
fp.LocalPort = uint16(data[ihl+1]) //code
|
||||
} else {
|
||||
fp.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2])
|
||||
fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
|
||||
fp.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2]) //src port
|
||||
fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) //dst port
|
||||
}
|
||||
} else {
|
||||
fp.LocalAddr, _ = netip.AddrFromSlice(data[12:16])
|
||||
fp.RemoteAddr, _ = netip.AddrFromSlice(data[16:20])
|
||||
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
|
||||
if fp.Fragment {
|
||||
fp.RemotePort = 0
|
||||
fp.LocalPort = 0
|
||||
} else if fp.Protocol == firewall.ProtoICMP { //note that orientation doesn't matter on ICMP
|
||||
fp.RemotePort = binary.BigEndian.Uint16(data[ihl+4 : ihl+6]) //identifier
|
||||
fp.LocalPort = uint16(data[ihl+1]) //code
|
||||
} else {
|
||||
fp.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2])
|
||||
fp.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
|
||||
fp.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2]) //src port
|
||||
fp.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) //dst port
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -494,7 +500,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
|||
return false
|
||||
}
|
||||
|
||||
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
|
||||
dropReason := f.firewall.Drop(*fwPacket, out, true, hostinfo, f.pki.GetCAPool(), localCache)
|
||||
if dropReason != nil {
|
||||
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
|
||||
// This gives us a buffer to build the reject packet in
|
||||
|
|
|
|||
|
|
@ -13,22 +13,22 @@ import (
|
|||
const DefaultMTU = 1300
|
||||
|
||||
// TODO: We may be able to remove routines
|
||||
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
|
||||
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error)
|
||||
|
||||
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
||||
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error) {
|
||||
switch {
|
||||
case c.GetBool("tun.disabled", false):
|
||||
tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
|
||||
return tun, nil
|
||||
|
||||
default:
|
||||
return newTun(c, l, vpnNetworks, routines > 1)
|
||||
return newTun(c, l, vpnNetworks, unsafeNetworks, routines > 1)
|
||||
}
|
||||
}
|
||||
|
||||
func NewFdDeviceFromConfig(fd *int) DeviceFactory {
|
||||
return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
||||
return newTunFromFd(c, l, *fd, vpnNetworks)
|
||||
return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error) {
|
||||
return newTunFromFd(c, l, *fd, vpnNetworks, unsafeNetworks)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ type tun struct {
|
|||
l *logrus.Logger
|
||||
}
|
||||
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix, _ []netip.Prefix) (*tun, error) {
|
||||
// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
|
||||
// Be sure not to call file.Fd() as it will set the fd to blocking mode.
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||
|
|
@ -53,7 +53,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net
|
|||
return t, nil
|
||||
}
|
||||
|
||||
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
|
||||
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ []netip.Prefix, _ bool) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTun not supported in Android")
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -79,7 +79,7 @@ type ifreqAlias6 struct {
|
|||
Lifetime addrLifetime
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
name := c.GetString("tun.dev", "")
|
||||
ifIndex := -1
|
||||
if name != "" && name != "utun" {
|
||||
|
|
@ -153,7 +153,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
|
|||
return
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -199,11 +199,11 @@ func (t *tun) Close() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
// Try to open existing tun device
|
||||
var fd int
|
||||
var err error
|
||||
|
|
|
|||
|
|
@ -28,11 +28,11 @@ type tun struct {
|
|||
l *logrus.Logger
|
||||
}
|
||||
|
||||
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
|
||||
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ []netip.Prefix, _ bool) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTun not supported in iOS")
|
||||
}
|
||||
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix, _ []netip.Prefix) (*tun, error) {
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
|
||||
t := &tun{
|
||||
vpnNetworks: vpnNetworks,
|
||||
|
|
|
|||
|
|
@ -26,14 +26,15 @@ import (
|
|||
|
||||
type tun struct {
|
||||
io.ReadWriteCloser
|
||||
fd int
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
MaxMTU int
|
||||
DefaultMTU int
|
||||
TXQueueLen int
|
||||
deviceIndex int
|
||||
ioctlFd uintptr
|
||||
fd int
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
unsafeNetworks []netip.Prefix
|
||||
MaxMTU int
|
||||
DefaultMTU int
|
||||
TXQueueLen int
|
||||
deviceIndex int
|
||||
ioctlFd uintptr
|
||||
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
|
|
@ -71,10 +72,10 @@ type ifreqQLEN struct {
|
|||
pad [8]byte
|
||||
}
|
||||
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix) (*tun, error) {
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||
|
||||
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
||||
t, err := newTunGeneric(c, l, file, vpnNetworks, unsafeNetworks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -84,7 +85,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net
|
|||
return t, nil
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
|
||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
|
||||
|
|
@ -119,7 +120,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
|||
name := strings.Trim(string(req.Name[:]), "\x00")
|
||||
|
||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
||||
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
||||
t, err := newTunGeneric(c, l, file, vpnNetworks, unsafeNetworks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -129,11 +130,12 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
|||
return t, nil
|
||||
}
|
||||
|
||||
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix) (*tun, error) {
|
||||
t := &tun{
|
||||
ReadWriteCloser: file,
|
||||
fd: int(file.Fd()),
|
||||
vpnNetworks: vpnNetworks,
|
||||
unsafeNetworks: unsafeNetworks,
|
||||
TXQueueLen: c.GetInt("tun.tx_queue", 500),
|
||||
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
|
||||
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
|
||||
|
|
@ -423,6 +425,27 @@ func (t *tun) setMTU() {
|
|||
}
|
||||
}
|
||||
|
||||
func (t *tun) setSnatRoute() error {
|
||||
snataddr := netip.MustParsePrefix("169.254.55.96/32") //todo get this from elsewhere? Or maybe we should pick it, and feed it back out to the firewall?
|
||||
dr := &net.IPNet{
|
||||
IP: snataddr.Masked().Addr().AsSlice(),
|
||||
Mask: net.CIDRMask(snataddr.Bits(), snataddr.Addr().BitLen()),
|
||||
}
|
||||
|
||||
nr := netlink.Route{
|
||||
LinkIndex: t.deviceIndex,
|
||||
Dst: dr,
|
||||
//todo do we need these other options?
|
||||
//MTU: t.DefaultMTU,
|
||||
//AdvMSS: t.advMSS(Route{}),
|
||||
Scope: unix.RT_SCOPE_LINK,
|
||||
//Protocol: unix.RTPROT_KERNEL,
|
||||
Table: unix.RT_TABLE_MAIN,
|
||||
Type: unix.RTN_UNICAST,
|
||||
}
|
||||
return netlink.RouteReplace(&nr)
|
||||
}
|
||||
|
||||
func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
|
||||
dr := &net.IPNet{
|
||||
IP: cidr.Masked().Addr().AsSlice(),
|
||||
|
|
@ -499,6 +522,18 @@ func (t *tun) addRoutes(logErrors bool) error {
|
|||
}
|
||||
}
|
||||
|
||||
onlyV6Addresses := false
|
||||
for _, n := range t.vpnNetworks {
|
||||
if n.Addr().Is6() {
|
||||
onlyV6Addresses = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(t.unsafeNetworks) != 0 && onlyV6Addresses {
|
||||
return t.setSnatRoute()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -70,11 +70,11 @@ type tun struct {
|
|||
|
||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
// Try to open tun device
|
||||
var err error
|
||||
deviceName := c.GetString("tun.dev", "")
|
||||
|
|
|
|||
|
|
@ -63,11 +63,11 @@ type tun struct {
|
|||
|
||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in openbsd")
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
// Try to open tun device
|
||||
var err error
|
||||
deviceName := c.GetString("tun.dev", "")
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ type TestTun struct {
|
|||
TxPackets chan []byte // Packets transmitted outside by nebula
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*TestTun, error) {
|
||||
_, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -49,7 +49,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||
}, nil
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*TestTun, error) {
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*TestTun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported")
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -38,11 +38,11 @@ type winTun struct {
|
|||
tun *wintun.NativeTun
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) {
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (Device, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*winTun, error) {
|
||||
err := checkWinTunExists()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can not load the wintun driver: %w", err)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import (
|
|||
"github.com/slackhq/nebula/routing"
|
||||
)
|
||||
|
||||
func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
||||
func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error) {
|
||||
return NewUserDevice(vpnNetworks)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,12 @@ import (
|
|||
// - https://github.com/skeeto/hash-prospector
|
||||
// [16 21f0aaad 15 d35a2d97 15] = 0.10760229515479501
|
||||
func hashPacket(p *firewall.Packet) int {
|
||||
x := (uint32(p.LocalPort) << 16) | uint32(p.RemotePort)
|
||||
var x uint32
|
||||
if p.Protocol == firewall.ProtoICMP {
|
||||
x = 0 //Don't attempt to use ICMP's code/id/etc to balance
|
||||
} else {
|
||||
x = (uint32(p.LocalPort) << 16) | uint32(p.RemotePort)
|
||||
}
|
||||
x ^= x >> 16
|
||||
x *= 0x21f0aaad
|
||||
x ^= x >> 15
|
||||
|
|
|
|||
91
snat.go
Normal file
91
snat.go
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
func recalcIPv4Checksum(data []byte, oldSrcIP netip.Addr, newSrcIP netip.Addr) {
|
||||
oldChecksum := binary.BigEndian.Uint16(data[10:12])
|
||||
//because of how checksums work, we can re-use this function
|
||||
checksum := calcNewTransportChecksum(oldChecksum, oldSrcIP, 0, newSrcIP, 0)
|
||||
binary.BigEndian.PutUint16(data[10:12], checksum)
|
||||
}
|
||||
|
||||
func calcNewTransportChecksum(oldChecksum uint16, oldSrcIP netip.Addr, oldSrcPort uint16, newSrcIP netip.Addr, newSrcPort uint16) uint16 {
|
||||
oldIP := binary.BigEndian.Uint32(oldSrcIP.AsSlice())
|
||||
newIP := binary.BigEndian.Uint32(newSrcIP.AsSlice())
|
||||
|
||||
// Start with inverted checksum
|
||||
checksum := uint32(^oldChecksum)
|
||||
|
||||
// Subtract old IP (as two 16-bit words)
|
||||
checksum += uint32(^uint16(oldIP >> 16))
|
||||
checksum += uint32(^uint16(oldIP & 0xFFFF))
|
||||
|
||||
// Subtract old port
|
||||
checksum += uint32(^oldSrcPort)
|
||||
|
||||
// Add new IP (as two 16-bit words)
|
||||
checksum += uint32(newIP >> 16)
|
||||
checksum += uint32(newIP & 0xFFFF)
|
||||
|
||||
// Add new port
|
||||
checksum += uint32(newSrcPort)
|
||||
|
||||
// Fold carries
|
||||
for checksum > 0xFFFF {
|
||||
checksum = (checksum & 0xFFFF) + (checksum >> 16)
|
||||
}
|
||||
|
||||
// Return ones' complement
|
||||
return ^uint16(checksum)
|
||||
}
|
||||
|
||||
func recalcV4TransportChecksum(offsetInsideHeader int, data []byte, oldSrcIP netip.AddrPort, newSrcIP netip.AddrPort) {
|
||||
ipHeaderOffset := int(data[0]&0x0F) * 4
|
||||
offset := ipHeaderOffset + offsetInsideHeader
|
||||
oldcsum := binary.BigEndian.Uint16(data[offset : offset+2])
|
||||
checksum := calcNewTransportChecksum(oldcsum, oldSrcIP.Addr(), oldSrcIP.Port(), newSrcIP.Addr(), newSrcIP.Port())
|
||||
binary.BigEndian.PutUint16(data[offset:offset+2], checksum)
|
||||
}
|
||||
|
||||
func recalcUDPv4Checksum(data []byte, oldSrcIP netip.AddrPort, newSrcIP netip.AddrPort) {
|
||||
const offsetInsideHeader = 6
|
||||
recalcV4TransportChecksum(offsetInsideHeader, data, oldSrcIP, newSrcIP)
|
||||
}
|
||||
|
||||
func recalcTCPv4Checksum(data []byte, oldSrcIP netip.AddrPort, newSrcIP netip.AddrPort) {
|
||||
const offsetInsideHeader = 16
|
||||
recalcV4TransportChecksum(offsetInsideHeader, data, oldSrcIP, newSrcIP)
|
||||
}
|
||||
|
||||
func calcNewICMPChecksum(oldChecksum uint16, oldCode uint16, newCode uint16, oldID uint16, newID uint16) uint16 {
|
||||
// Start with inverted checksum
|
||||
checksum := uint32(^oldChecksum)
|
||||
|
||||
// Subtract old stuff
|
||||
checksum += uint32(^oldCode)
|
||||
checksum += uint32(^oldID)
|
||||
|
||||
// Add new stuff
|
||||
checksum += uint32(newCode)
|
||||
checksum += uint32(newID)
|
||||
|
||||
// Fold carries
|
||||
for checksum > 0xFFFF {
|
||||
checksum = (checksum & 0xFFFF) + (checksum >> 16)
|
||||
}
|
||||
|
||||
// Return ones' complement
|
||||
return ^uint16(checksum)
|
||||
}
|
||||
|
||||
func recalcICMPv4Checksum(data []byte, oldCode uint16, newCode uint16, oldID uint16, newID uint16) {
|
||||
const offsetInsideHeader = 2
|
||||
ipHeaderOffset := int(data[0]&0x0F) * 4
|
||||
offset := ipHeaderOffset + offsetInsideHeader
|
||||
oldChecksum := binary.BigEndian.Uint16(data[offset : offset+2])
|
||||
checksum := calcNewICMPChecksum(oldChecksum, oldCode, newCode, oldID, newID)
|
||||
binary.BigEndian.PutUint16(data[offset:offset+2], checksum)
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue