This commit is contained in:
Jack Doan 2026-01-22 17:32:44 +00:00 committed by GitHub
commit e25d2aeb73
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 631 additions and 106 deletions

View file

@ -441,7 +441,7 @@ func (c *certificateV2) validate() error {
}
} else if network.Addr().Is4() {
if !hasV4Networks {
return NewErrInvalidCertificateProperties("IPv4 unsafe networks require an IPv4 address assignment: %s", network)
//return NewErrInvalidCertificateProperties("IPv4 unsafe networks require an IPv4 address assignment: %s", network)
}
}
}

View file

@ -2,6 +2,7 @@ package nebula
import (
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
@ -26,6 +27,19 @@ type FirewallInterface interface {
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr string, caName string, caSha string) error
}
type snatInfo struct {
//Src is the source IP+port to write into unsafe-route-bound packet
Src netip.AddrPort
//SrcVpnIp is the overlay IP associated with this flow. It's needed to associate reply traffic so we can get it back to the right host.
SrcVpnIp netip.Addr
//SnatPort is the port to rewrite into an overlay-bound packet
SnatPort uint16
}
func (s *snatInfo) Valid() bool {
return s.Src.IsValid()
}
type conn struct {
Expires time.Time // Time when this conntrack entry will expire
@ -34,6 +48,9 @@ type conn struct {
// fields pack for free after the uint32 above
incoming bool
rulesVersion uint16
//for SNAT support
snat snatInfo
}
// TODO: need conntrack max tracked connections handling
@ -66,6 +83,7 @@ type Firewall struct {
defaultLocalCIDRAny bool
incomingMetrics firewallMetrics
outgoingMetrics firewallMetrics
snatAddr netip.Addr
l *logrus.Logger
}
@ -149,12 +167,14 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
tmax = defaultTimeout
}
hasV4Networks := false
routableNetworks := new(bart.Lite)
var assignedNetworks []netip.Prefix
for _, network := range c.Networks() {
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
routableNetworks.Insert(nprefix)
assignedNetworks = append(assignedNetworks, network)
hasV4Networks = hasV4Networks || network.Addr().Is4()
}
hasUnsafeNetworks := false
@ -163,6 +183,11 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
hasUnsafeNetworks = true
}
snatAddr := netip.Addr{}
if hasUnsafeNetworks && !hasV4Networks {
snatAddr = netip.MustParseAddr("169.254.55.96") //todo this needs to come from the config, or perhaps the tun
}
return &Firewall{
Conntrack: &FirewallConntrack{
Conns: make(map[firewall.Packet]*conn),
@ -176,6 +201,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
routableNetworks: routableNetworks,
assignedNetworks: assignedNetworks,
hasUnsafeNetworks: hasUnsafeNetworks,
snatAddr: snatAddr,
l: l,
incomingMetrics: firewallMetrics{
@ -401,22 +427,131 @@ var ErrInvalidRemoteIP = errors.New("remote address is not in remote certificate
var ErrInvalidLocalIP = errors.New("local address is not in list of handled local addresses")
var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
func (f *Firewall) unSnat(data []byte, fp *firewall.Packet, c *conn) netip.Addr {
if c == nil {
//unfortunately this needs to lock. Surely there's a better way, but I need to make this flow at all first.
c = f.peek(*fp)
}
if c == nil {
return netip.Addr{}
}
if !c.snat.Valid() {
return netip.Addr{}
}
oldIP := netip.AddrPortFrom(f.snatAddr, fp.RemotePort)
//change dst IP
copy(data[16:], c.snat.Src.Addr().AsSlice())
recalcIPv4Checksum(data, oldIP.Addr(), c.snat.Src.Addr())
ipHeaderLen := int(data[0]&0x0F) * 4
//dst port is at offset 2
dstport := ipHeaderLen + 2
switch fp.Protocol {
case firewall.ProtoICMP:
binary.BigEndian.PutUint16(data[ipHeaderLen+4:ipHeaderLen+6], c.snat.Src.Port())
icmpCode := uint16(data[ipHeaderLen+1]) //todo not snatting on this yet (but Linux would)
recalcICMPv4Checksum(data, icmpCode, icmpCode, c.snat.SnatPort, c.snat.Src.Port())
case firewall.ProtoUDP:
binary.BigEndian.PutUint16(data[dstport:dstport+2], c.snat.Src.Port())
recalcUDPv4Checksum(data, oldIP, c.snat.Src)
case firewall.ProtoTCP:
binary.BigEndian.PutUint16(data[dstport:dstport+2], c.snat.Src.Port())
recalcTCPv4Checksum(data, oldIP, c.snat.Src)
}
return c.snat.SrcVpnIp
}
func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo *HostInfo) {
if c.snat.Valid() {
//old flow
fp.RemoteAddr = f.snatAddr
fp.RemotePort = c.snat.SnatPort
} else if hostinfo.vpnAddrs[0].Is6() {
//we got a new flow
c.snat.Src = netip.AddrPortFrom(fp.RemoteAddr, fp.RemotePort)
c.snat.SrcVpnIp = hostinfo.vpnAddrs[0]
fp.RemoteAddr = f.snatAddr
//find a new port to use, if needed
for {
existingFlow := f.peek(*fp) //locking and unlocking for each peek is slow, but simple for now
if existingFlow == nil {
break //yay, we can use this port
}
//increment and retry. There's probably better strategies out there
fp.RemotePort++
if fp.RemotePort < 0x7ff {
fp.RemotePort += 0x7ff // keep it ephemeral for now
} //todo if we're totally out of ports this loops forever. Probably not good.
}
c.snat.SnatPort = fp.RemotePort
} else {
f.l.WithFields(logrus.Fields{
"fp": *fp,
"conn": *c,
"hostinfo": hostinfo,
}).Error("this packet cannot be snatted")
return
}
newIP := netip.AddrPortFrom(f.snatAddr, c.snat.SnatPort)
//change src IP
copy(data[12:], f.snatAddr.AsSlice())
recalcIPv4Checksum(data, c.snat.Src.Addr(), newIP.Addr())
ipHeaderLen := int(data[0]&0x0F) * 4
switch fp.Protocol {
case firewall.ProtoICMP:
binary.BigEndian.PutUint16(data[ipHeaderLen+4:ipHeaderLen+6], c.snat.SnatPort)
icmpCode := uint16(data[ipHeaderLen+1]) //todo not snatting on this yet (but Linux would)
recalcICMPv4Checksum(data, icmpCode, icmpCode, c.snat.Src.Port(), c.snat.SnatPort)
case firewall.ProtoUDP:
//src port is at offset 0
binary.BigEndian.PutUint16(data[ipHeaderLen:ipHeaderLen+2], c.snat.SnatPort)
recalcUDPv4Checksum(data, c.snat.Src, newIP)
case firewall.ProtoTCP:
//src port is at offset 0
binary.BigEndian.PutUint16(data[ipHeaderLen:ipHeaderLen+2], c.snat.SnatPort)
recalcTCPv4Checksum(data, c.snat.Src, newIP)
}
}
// Drop returns an error if the packet should be dropped, explaining why. It
// returns nil if the packet should not be dropped.
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error {
func (f *Firewall) Drop(fp firewall.Packet, pkt []byte, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error {
specialSnatMode := f.hasUnsafeNetworks && fp.IsIPv4() && h.HasOnlyV6Addresses() //todo I wish I only set this once somehow
table := f.OutRules
if incoming {
table = f.InRules
}
// Check if we spoke to this tuple, if we did then allow this packet
if f.inConns(fp, h, caPool, localCache) {
return nil
// Check the cache first, iff not snatting
if localCache != nil && !specialSnatMode {
if _, ok := localCache[fp]; ok {
return nil //packet matched the cache, we're not snatting, we can return early!
}
}
c := f.inConns(fp, h, caPool, localCache)
if c != nil {
//can't return yet, need to snat maybe
goto snat
}
// Make sure remote address matches nebula certificate, and determine how to treat it
if h.networks == nil {
// Simple case: Certificate has one address and no unsafe networks
if h.vpnAddrs[0] != fp.RemoteAddr {
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP
}
if !specialSnatMode {
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP //todo!
} //else we're in special snat mode, and we need to apply more checks below
} //else? all good, fall through
} else {
//todo check for srcsnortaddr here too?
nwType, ok := h.networks.Lookup(fp.RemoteAddr)
if !ok {
f.metrics(incoming).droppedRemoteAddr.Inc(1)
@ -426,9 +561,12 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
case NetworkTypeVPN:
break // nothing special
case NetworkTypeVPNPeer:
//todo we might need a specialSnatMode case in here to handle routers with v4 addresses when we don't also have a v4 address?
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrPeerRejected // reject for now, one day this may have different FW rules
case NetworkTypeUnsafe:
//intentionally excluding f.hasUnsafeNetworks -- this is what lets routers talk back to us with our unsafe traffic!
specialSnatMode = fp.IsIPv4() && h.HasOnlyV6Addresses() && f.assignedNetworks[0].Addr().Is6()
break // nothing special, one day this may have different FW rules
default:
f.metrics(incoming).droppedRemoteAddr.Inc(1)
@ -437,16 +575,12 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
}
// Make sure we are supposed to be handling this local ip address
if !f.routableNetworks.Contains(fp.LocalAddr) {
//todo I'm not sure I trust this heuristic
if !specialSnatMode && !f.routableNetworks.Contains(fp.LocalAddr) {
f.metrics(incoming).droppedLocalAddr.Inc(1)
return ErrInvalidLocalIP
}
table := f.OutRules
if incoming {
table = f.InRules
}
// We now know which firewall table to check against
if !table.match(fp, incoming, h.ConnectionState.peerCert, caPool) {
f.metrics(incoming).droppedNoRule.Inc(1)
@ -454,7 +588,16 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
}
// We always want to conntrack since it is a faster operation
f.addConn(fp, incoming)
c = f.addConn(fp, incoming)
snat:
if incoming {
if specialSnatMode {
//todo do not snat if you are not a router for the destination -- for now, just if you're not a router
f.applySnat(pkt, &fp, c, h)
f.dupeConn(fp, c) //track the snatted flow with the same expiration as the unsnatted version
}
} //outgoing snat is handled before this function is called (for now!)
return nil
}
@ -483,12 +626,23 @@ func (f *Firewall) EmitStats() {
metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
}
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool {
if localCache != nil {
if _, ok := localCache[fp]; ok {
return true
}
func (f *Firewall) peek(fp firewall.Packet) *conn {
conntrack := f.Conntrack
conntrack.Lock()
// Purge every time we test
ep, has := conntrack.TimerWheel.Purge()
if has {
f.evict(ep)
}
c := conntrack.Conns[fp]
conntrack.Unlock()
return c
}
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) *conn {
conntrack := f.Conntrack
conntrack.Lock()
@ -500,9 +654,22 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
c, ok := conntrack.Conns[fp]
if !ok && fp.Protocol == firewall.ProtoICMP {
//todo this seems like it will also bite me
oldRemote := fp.RemotePort
oldLocal := fp.LocalPort
fp.RemotePort = 0
fp.LocalPort = 0
c, ok = conntrack.Conns[fp]
if ok {
fp.RemotePort = oldRemote
fp.LocalPort = oldLocal
}
}
if !ok {
conntrack.Unlock()
return false
return nil
}
if c.rulesVersion != f.rulesVersion {
@ -525,7 +692,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
}
delete(conntrack.Conns, fp)
conntrack.Unlock()
return false
return nil
}
if f.l.Level >= logrus.DebugLevel {
@ -555,12 +722,11 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
localCache[fp] = struct{}{}
}
return true
return c
}
func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
func (f *Firewall) packetTimeout(fp firewall.Packet) time.Duration {
var timeout time.Duration
c := &conn{}
switch fp.Protocol {
case firewall.ProtoTCP:
@ -570,7 +736,25 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
default:
timeout = f.DefaultTimeout
}
return timeout
}
func (f *Firewall) dupeConn(fp firewall.Packet, c *conn) {
conntrack := f.Conntrack
conntrack.Lock()
if _, ok := conntrack.Conns[fp]; !ok {
conntrack.TimerWheel.Advance(time.Now())
conntrack.TimerWheel.Add(fp, f.packetTimeout(fp))
}
conntrack.Conns[fp] = c
conntrack.Unlock()
}
func (f *Firewall) addConn(fp firewall.Packet, incoming bool) *conn {
c := &conn{}
timeout := f.packetTimeout(fp)
conntrack := f.Conntrack
conntrack.Lock()
if _, ok := conntrack.Conns[fp]; !ok {
@ -584,7 +768,18 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
c.rulesVersion = f.rulesVersion
c.Expires = time.Now().Add(timeout)
conntrack.Conns[fp] = c
//todo will this bite me somehow?
if fp.Protocol == firewall.ProtoICMP {
//not required for ICMPv6 because we don't decode or SNAT it
//create a duplicate conntrack entry with all the port information zeroed?
fp.RemotePort = 0
fp.LocalPort = 0
conntrack.Conns[fp] = c
}
conntrack.Unlock()
return c
}
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
@ -662,6 +857,13 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCer
var port int32
if p.Protocol == firewall.ProtoICMP {
// port numbers are re-used for connection tracking and SNAT,
// but we don't want to actually filter on them for ICMP
// ICMP6 is omitted because we don't attempt to parse code/identifier/etc out of ICMP6
return fp[firewall.PortAny].match(p, c, caPool)
}
if p.Fragment {
port = firewall.PortFragment
} else if incoming {

View file

@ -22,12 +22,20 @@ const (
type Packet struct {
LocalAddr netip.Addr
RemoteAddr netip.Addr
LocalPort uint16
// LocalPort is the destination port for incoming traffic, or the source port for outgoing.
// For ICMP, it's "code". //todo also store "type?" would need to decode replies, which sucks
LocalPort uint16
// RemotePort is the source port for incoming traffic, or the destination port for outgoing.
// For ICMP, it's the "identifier". This is only used for connection tracking, actual firewall rules will not filter on ICMP identifier
RemotePort uint16
Protocol uint8
Fragment bool
}
func (fp *Packet) IsIPv4() bool {
return fp.LocalAddr.Is4() && fp.RemoteAddr.Is4()
}
func (fp *Packet) Copy() *Packet {
return &Packet{
LocalAddr: fp.LocalAddr,

View file

@ -212,44 +212,44 @@ func TestFirewall_Drop(t *testing.T) {
cp := cert.NewCAPool()
// Drop outbound
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, nil, false, &h, cp, nil))
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
// Allow outbound because conntrack
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil))
// test remote mismatch
oldRemote := p.RemoteAddr
p.RemoteAddr = netip.MustParseAddr("1.2.3.10")
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrInvalidRemoteIP)
p.RemoteAddr = oldRemote
// ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
}
func TestFirewall_DropV6(t *testing.T) {
@ -291,44 +291,44 @@ func TestFirewall_DropV6(t *testing.T) {
cp := cert.NewCAPool()
// Drop outbound
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, nil, false, &h, cp, nil))
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
// Allow outbound because conntrack
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil))
// test remote mismatch
oldRemote := p.RemoteAddr
p.RemoteAddr = netip.MustParseAddr("fd12::56")
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrInvalidRemoteIP)
p.RemoteAddr = oldRemote
// ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
}
func BenchmarkFirewallTable_match(b *testing.B) {
@ -536,10 +536,10 @@ func TestFirewall_Drop2(t *testing.T) {
cp := cert.NewCAPool()
// h1/c1 lacks the proper groups
require.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule)
require.ErrorIs(t, fw.Drop(p, nil, true, &h1, cp, nil), ErrNoMatchingRule)
// c has the proper groups
resetConntrack(fw)
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
}
func TestFirewall_Drop3(t *testing.T) {
@ -617,18 +617,18 @@ func TestFirewall_Drop3(t *testing.T) {
cp := cert.NewCAPool()
// c1 should pass because host match
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h1, cp, nil))
// c2 should pass because ca sha match
resetConntrack(fw)
require.NoError(t, fw.Drop(p, true, &h2, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h2, cp, nil))
// c3 should fail because no match
resetConntrack(fw)
assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p, nil, true, &h3, cp, nil), ErrNoMatchingRule)
// Test a remote address match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", ""))
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h1, cp, nil))
}
func TestFirewall_Drop3V6(t *testing.T) {
@ -666,7 +666,7 @@ func TestFirewall_Drop3V6(t *testing.T) {
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
cp := cert.NewCAPool()
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
}
func TestFirewall_DropConntrackReload(t *testing.T) {
@ -708,12 +708,12 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
cp := cert.NewCAPool()
// Drop outbound
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
// Allow outbound because conntrack
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil))
oldFw := fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
@ -722,7 +722,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
fw.rulesVersion = oldFw.rulesVersion + 1
// Allow outbound because conntrack and new rules allow port 10
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil))
oldFw = fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
@ -731,7 +731,160 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
fw.rulesVersion = oldFw.rulesVersion + 1
// Drop outbound because conntrack doesn't match new ruleset
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrNoMatchingRule)
}
func TestFirewall_ICMPPortBehavior(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
network := netip.MustParsePrefix("1.2.3.4/24")
c := cert.CachedCertificate{
Certificate: &dummyCert{
name: "host1",
networks: []netip.Prefix{network},
groups: []string{"default-group"},
issuer: "signer-shasum",
},
InvertedGroups: map[string]struct{}{"default-group": {}},
}
h := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &c,
},
vpnAddrs: []netip.Addr{network.Addr()},
}
h.buildNetworks(myVpnNetworksTable, c.Certificate)
cp := cert.NewCAPool()
templ := firewall.Packet{
LocalAddr: netip.MustParseAddr("1.2.3.4"),
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
Protocol: firewall.ProtoICMP,
Fragment: false,
}
t.Run("ICMP allowed", func(t *testing.T) {
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 0, 0, []string{"any"}, "", "", "", "", ""))
t.Run("zero ports", func(t *testing.T) {
p := templ.Copy()
p.LocalPort = 0
p.RemotePort = 0
// Drop outbound
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil))
//now also allow outbound
require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil))
})
t.Run("nonzero ports", func(t *testing.T) {
p := templ.Copy()
p.LocalPort = 0xabcd
p.RemotePort = 0x1234
// Drop outbound
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil))
//now also allow outbound
require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil))
})
})
t.Run("Any proto, some ports allowed", func(t *testing.T) {
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 80, 444, []string{"any"}, "", "", "", "", ""))
t.Run("zero ports, still blocked", func(t *testing.T) {
p := templ.Copy()
p.LocalPort = 0
p.RemotePort = 0
// Drop outbound
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule)
//now also allow outbound
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
})
t.Run("nonzero ports, still blocked", func(t *testing.T) {
p := templ.Copy()
p.LocalPort = 0xabcd
p.RemotePort = 0x1234
// Drop outbound
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule)
//now also allow outbound
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
})
t.Run("nonzero, matching ports, still blocked", func(t *testing.T) {
p := templ.Copy()
p.LocalPort = 80
p.RemotePort = 80
// Drop outbound
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule)
//now also allow outbound
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
})
})
t.Run("Any proto, any port", func(t *testing.T) {
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
t.Run("zero ports, allowed", func(t *testing.T) {
p := templ.Copy()
p.LocalPort = 0
p.RemotePort = 0
// Drop outbound
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil))
//now also allow outbound
require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil))
})
t.Run("nonzero ports, allowed", func(t *testing.T) {
p := templ.Copy()
p.LocalPort = 0xabcd
p.RemotePort = 0x1234
// Drop outbound
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil))
//now also allow outbound
require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil))
})
t.Run("nonzero ports, allowed", func(t *testing.T) {
p := templ.Copy()
p.LocalPort = 0xabcd
p.RemotePort = 0x1234
// Drop outbound
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil))
//now also allow outbound with a different ID
p.RemotePort++
require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil))
})
})
}
func TestFirewall_DropIPSpoofing(t *testing.T) {
@ -777,7 +930,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
Protocol: firewall.ProtoUDP,
Fragment: false,
}
assert.Equal(t, fw.Drop(p, true, &h1, cp, nil), ErrInvalidRemoteIP)
assert.Equal(t, fw.Drop(p, nil, true, &h1, cp, nil), ErrInvalidRemoteIP)
}
func BenchmarkLookup(b *testing.B) {
@ -1184,7 +1337,7 @@ func (c *testcase) Test(t *testing.T, fw *Firewall) {
t.Helper()
cp := cert.NewCAPool()
resetConntrack(fw)
err := fw.Drop(c.p, true, c.h, cp, nil)
err := fw.Drop(c.p, nil, true, c.h, cp, nil)
if c.err == nil {
require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr)
} else {

View file

@ -224,6 +224,7 @@ const (
NetworkTypeVPNPeer
// NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks()
NetworkTypeUnsafe
//todo consider NetworkTypeLinkLocal or NetworkTypeSNAT
)
type HostInfo struct {
@ -277,6 +278,15 @@ type HostInfo struct {
lastUsed time.Time
}
func (i *HostInfo) HasOnlyV6Addresses() bool {
for _, vpnIp := range i.vpnAddrs {
if !vpnIp.Is6() {
return false
}
}
return true
}
type ViaSender struct {
UdpAddr netip.AddrPort
relayHI *HostInfo // relayHI is the host info object of the relay

View file

@ -48,9 +48,24 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
return
}
hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
})
var hostinfo *HostInfo
var ready bool
snatMode := f.firewall.snatAddr.IsValid() && fwPacket.RemoteAddr == f.firewall.snatAddr
if snatMode {
//todo unsnat happens here, would be nice to not
destVpnAddr := f.firewall.unSnat(packet, fwPacket, nil) //todo bail if we can't unsnat?
if destVpnAddr.IsValid() {
hostinfo, ready = f.getOrHandshakeNoRouting(destVpnAddr, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
})
} //otherwise, hostinfo will be nil
} else { //if we didn't need to unsnat
hostinfo, ready = f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
})
}
if hostinfo == nil {
f.rejectInside(packet, out, q)
@ -66,10 +81,9 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
return
}
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
dropReason := f.firewall.Drop(*fwPacket, packet, false, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason == nil {
f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
} else {
f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel {
@ -218,7 +232,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
}
// check if packet is in outbound fw rules
dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil)
dropReason := f.firewall.Drop(*fp, nil, false, hostinfo, f.pki.GetCAPool(), nil)
if dropReason != nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("fwPacket", fp).

View file

@ -135,7 +135,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
deviceFactory = overlay.NewDeviceFromConfig
}
tun, err = deviceFactory(c, l, pki.getCertState().myVpnNetworks, routines)
cs := pki.getCertState()
tun, err = deviceFactory(c, l, cs.myVpnNetworks, cs.GetDefaultCertificate().UnsafeNetworks(), routines)
if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err)
}

View file

@ -329,7 +329,7 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
switch proto {
case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader:
fp.Protocol = uint8(proto)
fp.RemotePort = 0
fp.RemotePort = 0 //we don't attempt to parse ICMPv6 because we don't SNAT it
fp.LocalPort = 0
fp.Fragment = false
return nil
@ -434,22 +434,28 @@ func parseV4(data []byte, incoming bool, fp *firewall.Packet) error {
if incoming {
fp.RemoteAddr, _ = netip.AddrFromSlice(data[12:16])
fp.LocalAddr, _ = netip.AddrFromSlice(data[16:20])
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
if fp.Fragment {
fp.RemotePort = 0
fp.LocalPort = 0
} else if fp.Protocol == firewall.ProtoICMP { //note that orientation doesn't matter on ICMP
fp.RemotePort = binary.BigEndian.Uint16(data[ihl+4 : ihl+6]) //identifier
fp.LocalPort = uint16(data[ihl+1]) //code
} else {
fp.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2])
fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
fp.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2]) //src port
fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) //dst port
}
} else {
fp.LocalAddr, _ = netip.AddrFromSlice(data[12:16])
fp.RemoteAddr, _ = netip.AddrFromSlice(data[16:20])
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
if fp.Fragment {
fp.RemotePort = 0
fp.LocalPort = 0
} else if fp.Protocol == firewall.ProtoICMP { //note that orientation doesn't matter on ICMP
fp.RemotePort = binary.BigEndian.Uint16(data[ihl+4 : ihl+6]) //identifier
fp.LocalPort = uint16(data[ihl+1]) //code
} else {
fp.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2])
fp.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
fp.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2]) //src port
fp.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) //dst port
}
}
@ -494,7 +500,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
return false
}
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
dropReason := f.firewall.Drop(*fwPacket, out, true, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason != nil {
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
// This gives us a buffer to build the reject packet in

View file

@ -13,22 +13,22 @@ import (
const DefaultMTU = 1300
// TODO: We may be able to remove routines
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error)
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error) {
switch {
case c.GetBool("tun.disabled", false):
tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
return tun, nil
default:
return newTun(c, l, vpnNetworks, routines > 1)
return newTun(c, l, vpnNetworks, unsafeNetworks, routines > 1)
}
}
func NewFdDeviceFromConfig(fd *int) DeviceFactory {
return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
return newTunFromFd(c, l, *fd, vpnNetworks)
return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error) {
return newTunFromFd(c, l, *fd, vpnNetworks, unsafeNetworks)
}
}

View file

@ -26,7 +26,7 @@ type tun struct {
l *logrus.Logger
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix, _ []netip.Prefix) (*tun, error) {
// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
// Be sure not to call file.Fd() as it will set the fd to blocking mode.
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
@ -53,7 +53,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net
return t, nil
}
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ []netip.Prefix, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in Android")
}

View file

@ -79,7 +79,7 @@ type ifreqAlias6 struct {
Lifetime addrLifetime
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) {
name := c.GetString("tun.dev", "")
ifIndex := -1
if name != "" && name != "utun" {
@ -153,7 +153,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
return
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
}

View file

@ -199,11 +199,11 @@ func (t *tun) Close() error {
return nil
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open existing tun device
var fd int
var err error

View file

@ -28,11 +28,11 @@ type tun struct {
l *logrus.Logger
}
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ []netip.Prefix, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in iOS")
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix, _ []netip.Prefix) (*tun, error) {
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
t := &tun{
vpnNetworks: vpnNetworks,

View file

@ -26,14 +26,15 @@ import (
type tun struct {
io.ReadWriteCloser
fd int
Device string
vpnNetworks []netip.Prefix
MaxMTU int
DefaultMTU int
TXQueueLen int
deviceIndex int
ioctlFd uintptr
fd int
Device string
vpnNetworks []netip.Prefix
unsafeNetworks []netip.Prefix
MaxMTU int
DefaultMTU int
TXQueueLen int
deviceIndex int
ioctlFd uintptr
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
@ -71,10 +72,10 @@ type ifreqQLEN struct {
pad [8]byte
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix) (*tun, error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
t, err := newTunGeneric(c, l, file, vpnNetworks)
t, err := newTunGeneric(c, l, file, vpnNetworks, unsafeNetworks)
if err != nil {
return nil, err
}
@ -84,7 +85,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net
return t, nil
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
@ -119,7 +120,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
name := strings.Trim(string(req.Name[:]), "\x00")
file := os.NewFile(uintptr(fd), "/dev/net/tun")
t, err := newTunGeneric(c, l, file, vpnNetworks)
t, err := newTunGeneric(c, l, file, vpnNetworks, unsafeNetworks)
if err != nil {
return nil, err
}
@ -129,11 +130,12 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
return t, nil
}
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix) (*tun, error) {
t := &tun{
ReadWriteCloser: file,
fd: int(file.Fd()),
vpnNetworks: vpnNetworks,
unsafeNetworks: unsafeNetworks,
TXQueueLen: c.GetInt("tun.tx_queue", 500),
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
@ -423,6 +425,27 @@ func (t *tun) setMTU() {
}
}
func (t *tun) setSnatRoute() error {
snataddr := netip.MustParsePrefix("169.254.55.96/32") //todo get this from elsewhere? Or maybe we should pick it, and feed it back out to the firewall?
dr := &net.IPNet{
IP: snataddr.Masked().Addr().AsSlice(),
Mask: net.CIDRMask(snataddr.Bits(), snataddr.Addr().BitLen()),
}
nr := netlink.Route{
LinkIndex: t.deviceIndex,
Dst: dr,
//todo do we need these other options?
//MTU: t.DefaultMTU,
//AdvMSS: t.advMSS(Route{}),
Scope: unix.RT_SCOPE_LINK,
//Protocol: unix.RTPROT_KERNEL,
Table: unix.RT_TABLE_MAIN,
Type: unix.RTN_UNICAST,
}
return netlink.RouteReplace(&nr)
}
func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
dr := &net.IPNet{
IP: cidr.Masked().Addr().AsSlice(),
@ -499,6 +522,18 @@ func (t *tun) addRoutes(logErrors bool) error {
}
}
onlyV6Addresses := false
for _, n := range t.vpnNetworks {
if n.Addr().Is6() {
onlyV6Addresses = true
break
}
}
if len(t.unsafeNetworks) != 0 && onlyV6Addresses {
return t.setSnatRoute()
}
return nil
}

View file

@ -70,11 +70,11 @@ type tun struct {
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open tun device
var err error
deviceName := c.GetString("tun.dev", "")

View file

@ -63,11 +63,11 @@ type tun struct {
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in openbsd")
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open tun device
var err error
deviceName := c.GetString("tun.dev", "")

View file

@ -28,7 +28,7 @@ type TestTun struct {
TxPackets chan []byte // Packets transmitted outside by nebula
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) {
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*TestTun, error) {
_, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true)
if err != nil {
return nil, err
@ -49,7 +49,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
}, nil
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*TestTun, error) {
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*TestTun, error) {
return nil, fmt.Errorf("newTunFromFd not supported")
}

View file

@ -38,11 +38,11 @@ type winTun struct {
tun *wintun.NativeTun
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) {
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (Device, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) {
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*winTun, error) {
err := checkWinTunExists()
if err != nil {
return nil, fmt.Errorf("can not load the wintun driver: %w", err)

View file

@ -9,7 +9,7 @@ import (
"github.com/slackhq/nebula/routing"
)
func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error) {
return NewUserDevice(vpnNetworks)
}

View file

@ -12,7 +12,12 @@ import (
// - https://github.com/skeeto/hash-prospector
// [16 21f0aaad 15 d35a2d97 15] = 0.10760229515479501
func hashPacket(p *firewall.Packet) int {
x := (uint32(p.LocalPort) << 16) | uint32(p.RemotePort)
var x uint32
if p.Protocol == firewall.ProtoICMP {
x = 0 //Don't attempt to use ICMP's code/id/etc to balance
} else {
x = (uint32(p.LocalPort) << 16) | uint32(p.RemotePort)
}
x ^= x >> 16
x *= 0x21f0aaad
x ^= x >> 15

91
snat.go Normal file
View file

@ -0,0 +1,91 @@
package nebula
import (
"encoding/binary"
"net/netip"
)
func recalcIPv4Checksum(data []byte, oldSrcIP netip.Addr, newSrcIP netip.Addr) {
oldChecksum := binary.BigEndian.Uint16(data[10:12])
//because of how checksums work, we can re-use this function
checksum := calcNewTransportChecksum(oldChecksum, oldSrcIP, 0, newSrcIP, 0)
binary.BigEndian.PutUint16(data[10:12], checksum)
}
func calcNewTransportChecksum(oldChecksum uint16, oldSrcIP netip.Addr, oldSrcPort uint16, newSrcIP netip.Addr, newSrcPort uint16) uint16 {
oldIP := binary.BigEndian.Uint32(oldSrcIP.AsSlice())
newIP := binary.BigEndian.Uint32(newSrcIP.AsSlice())
// Start with inverted checksum
checksum := uint32(^oldChecksum)
// Subtract old IP (as two 16-bit words)
checksum += uint32(^uint16(oldIP >> 16))
checksum += uint32(^uint16(oldIP & 0xFFFF))
// Subtract old port
checksum += uint32(^oldSrcPort)
// Add new IP (as two 16-bit words)
checksum += uint32(newIP >> 16)
checksum += uint32(newIP & 0xFFFF)
// Add new port
checksum += uint32(newSrcPort)
// Fold carries
for checksum > 0xFFFF {
checksum = (checksum & 0xFFFF) + (checksum >> 16)
}
// Return ones' complement
return ^uint16(checksum)
}
func recalcV4TransportChecksum(offsetInsideHeader int, data []byte, oldSrcIP netip.AddrPort, newSrcIP netip.AddrPort) {
ipHeaderOffset := int(data[0]&0x0F) * 4
offset := ipHeaderOffset + offsetInsideHeader
oldcsum := binary.BigEndian.Uint16(data[offset : offset+2])
checksum := calcNewTransportChecksum(oldcsum, oldSrcIP.Addr(), oldSrcIP.Port(), newSrcIP.Addr(), newSrcIP.Port())
binary.BigEndian.PutUint16(data[offset:offset+2], checksum)
}
func recalcUDPv4Checksum(data []byte, oldSrcIP netip.AddrPort, newSrcIP netip.AddrPort) {
const offsetInsideHeader = 6
recalcV4TransportChecksum(offsetInsideHeader, data, oldSrcIP, newSrcIP)
}
func recalcTCPv4Checksum(data []byte, oldSrcIP netip.AddrPort, newSrcIP netip.AddrPort) {
const offsetInsideHeader = 16
recalcV4TransportChecksum(offsetInsideHeader, data, oldSrcIP, newSrcIP)
}
func calcNewICMPChecksum(oldChecksum uint16, oldCode uint16, newCode uint16, oldID uint16, newID uint16) uint16 {
// Start with inverted checksum
checksum := uint32(^oldChecksum)
// Subtract old stuff
checksum += uint32(^oldCode)
checksum += uint32(^oldID)
// Add new stuff
checksum += uint32(newCode)
checksum += uint32(newID)
// Fold carries
for checksum > 0xFFFF {
checksum = (checksum & 0xFFFF) + (checksum >> 16)
}
// Return ones' complement
return ^uint16(checksum)
}
func recalcICMPv4Checksum(data []byte, oldCode uint16, newCode uint16, oldID uint16, newID uint16) {
const offsetInsideHeader = 2
ipHeaderOffset := int(data[0]&0x0F) * 4
offset := ipHeaderOffset + offsetInsideHeader
oldChecksum := binary.BigEndian.Uint16(data[offset : offset+2])
checksum := calcNewICMPChecksum(oldChecksum, oldCode, newCode, oldID, newID)
binary.BigEndian.PutUint16(data[offset:offset+2], checksum)
}