This commit is contained in:
Jack Doan 2026-03-06 13:05:02 -06:00 committed by GitHub
commit 143d230e2d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
29 changed files with 2949 additions and 235 deletions

View file

@ -439,11 +439,7 @@ func (c *certificateV2) validate() error {
if !hasV6Networks {
return NewErrInvalidCertificateProperties("IPv6 unsafe networks require an IPv6 address assignment: %s", network)
}
} else if network.Addr().Is4() {
if !hasV4Networks {
return NewErrInvalidCertificateProperties("IPv4 unsafe networks require an IPv4 address assignment: %s", network)
}
}
} // as long as we have any IP address, IPv4 UnsafeNetworks are allowed
}
}

400
e2e/snat_test.go Normal file
View file

@ -0,0 +1,400 @@
//go:build e2e_testing
// +build e2e_testing
package e2e
import (
"encoding/binary"
"net/netip"
"testing"
"time"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/e2e/router"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// parseIPv4UDPPacket extracts source/dest IPs, ports, and payload from an IPv4 UDP packet.
func parseIPv4UDPPacket(t testing.TB, pkt []byte) (srcIP, dstIP netip.Addr, srcPort, dstPort uint16, payload []byte) {
t.Helper()
require.True(t, len(pkt) >= 28, "packet too short for IPv4+UDP header")
require.Equal(t, byte(0x45), pkt[0]&0xF0|pkt[0]&0x0F, "not a simple IPv4 packet (IHL!=5)")
srcIP, _ = netip.AddrFromSlice(pkt[12:16])
dstIP, _ = netip.AddrFromSlice(pkt[16:20])
ihl := int(pkt[0]&0x0F) * 4
require.True(t, len(pkt) >= ihl+8, "packet too short for UDP header")
srcPort = binary.BigEndian.Uint16(pkt[ihl : ihl+2])
dstPort = binary.BigEndian.Uint16(pkt[ihl+2 : ihl+4])
udpLen := binary.BigEndian.Uint16(pkt[ihl+4 : ihl+6])
payload = pkt[ihl+8 : ihl+int(udpLen)]
return
}
func TestSNAT_IPv6OnlyPeer_IPv4UnsafeTraffic(t *testing.T) {
// Scenario: Two IPv6-only VPN nodes. The "router" node has unsafe networks
// (192.168.0.0/16) in its cert and a configured SNAT address. The "sender"
// node has an unsafe route for 192.168.0.0/16 via the router.
//
// When sender injects an IPv4 packet destined for the unsafe network, it
// gets tunneled to the router. The router's firewall detects this is IPv4
// from an IPv6-only peer and applies SNAT, rewriting the source IP to the
// SNAT address before delivering it to TUN.
//
// When a reply comes back from TUN addressed to the SNAT address, the
// router un-SNATs it (restoring the original destination) and tunnels it
// back to the sender.
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
unsafePrefix := "192.168.0.0/16"
snatAddr := netip.MustParseAddr("169.254.42.42")
// Router: IPv6-only with unsafe networks and a manual SNAT address.
// Override inbound firewall with local_cidr: "any" so both IPv4 (unsafe)
// and IPv6 (VPN) traffic is accepted.
routerControl, routerVpnIpNet, routerUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks(
cert.Version2, ca, caKey, "router", "ff::1/64",
netip.MustParseAddrPort("[beef::1]:4242"),
unsafePrefix,
m{
"firewall": m{
"inbound": []m{{
"proto": "any",
"port": "any",
"host": "any",
"local_cidr": "any",
}},
},
"tun": m{
"snat_address_for_4over6": snatAddr.String(),
},
},
)
// Sender: IPv6-only with an unsafe route via the router
senderControl, _, _, _ := newSimpleServerWithUdp(
cert.Version2, ca, caKey, "sender", "ff::2/64",
netip.MustParseAddrPort("[beef::2]:4242"),
m{
"tun": m{
"unsafe_routes": []m{
{"route": unsafePrefix, "via": routerVpnIpNet[0].Addr().String()},
},
},
},
)
// Tell sender where the router lives
senderControl.InjectLightHouseAddr(routerVpnIpNet[0].Addr(), routerUdpAddr)
// Build the router and start both nodes
r := router.NewR(t, routerControl, senderControl)
defer r.RenderFlow()
routerControl.Start()
senderControl.Start()
// --- Outbound: sender -> IPv4 unsafe dest (via router with SNAT) ---
origSrcIP := netip.MustParseAddr("10.0.0.1")
unsafeDest := netip.MustParseAddr("192.168.1.1")
var origSrcPort uint16 = 12345
var dstPort uint16 = 80
t.Log("Sender injects an IPv4 packet to the unsafe network")
senderControl.InjectTunUDPPacket(unsafeDest, dstPort, origSrcIP, origSrcPort, []byte("snat me"))
t.Log("Route packets (handshake + data) until the router gets the packet on TUN")
snatPkt := r.RouteForAllUntilTxTun(routerControl)
t.Log("Verify the packet was SNATted")
gotSrcIP, gotDstIP, gotSrcPort, gotDstPort, gotPayload := parseIPv4UDPPacket(t, snatPkt)
assert.Equal(t, snatAddr, gotSrcIP, "source IP should be rewritten to the SNAT address")
assert.Equal(t, unsafeDest, gotDstIP, "destination IP should be unchanged")
assert.Equal(t, dstPort, gotDstPort, "destination port should be unchanged")
assert.Equal(t, []byte("snat me"), gotPayload, "payload should be unchanged")
// Capture the SNAT port (may differ from original if port was remapped)
snatPort := gotSrcPort
t.Logf("SNAT port: %d (original: %d)", snatPort, origSrcPort)
// --- Return: reply from unsafe dest -> un-SNATted back to sender ---
t.Log("Router injects a reply packet from the unsafe dest to the SNAT address")
routerControl.InjectTunUDPPacket(snatAddr, snatPort, unsafeDest, dstPort, []byte("reply from unsafe"))
t.Log("Route until sender gets the reply on TUN")
replyPkt := r.RouteForAllUntilTxTun(senderControl)
t.Log("Verify the reply was un-SNATted")
replySrcIP, replyDstIP, replySrcPort, replyDstPort, replyPayload := parseIPv4UDPPacket(t, replyPkt)
assert.Equal(t, unsafeDest, replySrcIP, "reply source should be the unsafe dest")
assert.Equal(t, origSrcIP, replyDstIP, "reply dest should be the original source IP (un-SNATted)")
assert.Equal(t, dstPort, replySrcPort, "reply source port should be the unsafe dest port")
assert.Equal(t, origSrcPort, replyDstPort, "reply dest port should be the original source port (un-SNATted)")
assert.Equal(t, []byte("reply from unsafe"), replyPayload, "payload should be unchanged")
r.RenderHostmaps("Final hostmaps", routerControl, senderControl)
// Also verify normal IPv6 VPN traffic still works between the nodes
t.Log("Verify normal IPv6 VPN tunnel still works")
assertTunnel(t, routerVpnIpNet[0].Addr(), senderControl.GetVpnAddrs()[0], routerControl, senderControl, r)
routerControl.Stop()
senderControl.Stop()
}
func TestSNAT_MultipleFlows(t *testing.T) {
// Test that multiple distinct IPv4 flows from the same IPv6-only peer
// are tracked independently through SNAT.
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
unsafePrefix := "192.168.0.0/16"
snatAddr := netip.MustParseAddr("169.254.42.42")
routerControl, routerVpnIpNet, routerUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks(
cert.Version2, ca, caKey, "router", "ff::1/64",
netip.MustParseAddrPort("[beef::1]:4242"),
unsafePrefix,
m{
"firewall": m{
"inbound": []m{{
"proto": "any",
"port": "any",
"host": "any",
"local_cidr": "any",
}},
},
"tun": m{
"snat_address_for_4over6": snatAddr.String(),
},
},
)
senderControl, _, _, _ := newSimpleServerWithUdp(
cert.Version2, ca, caKey, "sender", "ff::2/64",
netip.MustParseAddrPort("[beef::2]:4242"),
m{
"tun": m{
"unsafe_routes": []m{
{"route": unsafePrefix, "via": routerVpnIpNet[0].Addr().String()},
},
},
},
)
senderControl.InjectLightHouseAddr(routerVpnIpNet[0].Addr(), routerUdpAddr)
r := router.NewR(t, routerControl, senderControl)
defer r.RenderFlow()
r.CancelFlowLogs()
routerControl.Start()
senderControl.Start()
unsafeDest := netip.MustParseAddr("192.168.1.1")
// Send first flow
senderControl.InjectTunUDPPacket(unsafeDest, 80, netip.MustParseAddr("10.0.0.1"), 1111, []byte("flow1"))
pkt1 := r.RouteForAllUntilTxTun(routerControl)
srcIP1, _, srcPort1, _, payload1 := parseIPv4UDPPacket(t, pkt1)
assert.Equal(t, snatAddr, srcIP1)
assert.Equal(t, []byte("flow1"), payload1)
// Send second flow (different source port)
senderControl.InjectTunUDPPacket(unsafeDest, 80, netip.MustParseAddr("10.0.0.1"), 2222, []byte("flow2"))
pkt2 := r.RouteForAllUntilTxTun(routerControl)
srcIP2, _, srcPort2, _, payload2 := parseIPv4UDPPacket(t, pkt2)
assert.Equal(t, snatAddr, srcIP2)
assert.Equal(t, []byte("flow2"), payload2)
// The two flows should have different SNAT ports (since they're different conntracks)
t.Logf("Flow 1 SNAT port: %d, Flow 2 SNAT port: %d", srcPort1, srcPort2)
// Reply to flow 2 first (out of order)
routerControl.InjectTunUDPPacket(snatAddr, srcPort2, unsafeDest, 80, []byte("reply2"))
reply2 := r.RouteForAllUntilTxTun(senderControl)
_, replyDst2, _, replyDstPort2, replyPayload2 := parseIPv4UDPPacket(t, reply2)
assert.Equal(t, netip.MustParseAddr("10.0.0.1"), replyDst2)
assert.Equal(t, uint16(2222), replyDstPort2, "reply to flow 2 should restore original port 2222")
assert.Equal(t, []byte("reply2"), replyPayload2)
// Reply to flow 1
routerControl.InjectTunUDPPacket(snatAddr, srcPort1, unsafeDest, 80, []byte("reply1"))
reply1 := r.RouteForAllUntilTxTun(senderControl)
_, replyDst1, _, replyDstPort1, replyPayload1 := parseIPv4UDPPacket(t, reply1)
assert.Equal(t, netip.MustParseAddr("10.0.0.1"), replyDst1)
assert.Equal(t, uint16(1111), replyDstPort1, "reply to flow 1 should restore original port 1111")
assert.Equal(t, []byte("reply1"), replyPayload1)
routerControl.Stop()
senderControl.Stop()
}
// --- Adversarial SNAT E2E Tests ---
func TestSNAT_UnsolicitedReplyDropped(t *testing.T) {
// Without any outbound SNAT traffic, inject a packet from the router's TUN
// addressed to the SNAT address. The sender must never receive it because
// there's no conntrack entry to un-SNAT through.
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
unsafePrefix := "192.168.0.0/16"
snatAddr := netip.MustParseAddr("169.254.42.42")
routerControl, routerVpnIpNet, routerUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks(
cert.Version2, ca, caKey, "router", "ff::1/64",
netip.MustParseAddrPort("[beef::1]:4242"),
unsafePrefix,
m{
"firewall": m{
"inbound": []m{{
"proto": "any",
"port": "any",
"host": "any",
"local_cidr": "any",
}},
},
"tun": m{
"snat_address_for_4over6": snatAddr.String(),
},
},
)
senderControl, _, _, _ := newSimpleServerWithUdp(
cert.Version2, ca, caKey, "sender", "ff::2/64",
netip.MustParseAddrPort("[beef::2]:4242"),
m{
"tun": m{
"unsafe_routes": []m{
{"route": unsafePrefix, "via": routerVpnIpNet[0].Addr().String()},
},
},
},
)
senderControl.InjectLightHouseAddr(routerVpnIpNet[0].Addr(), routerUdpAddr)
r := router.NewR(t, routerControl, senderControl)
defer r.RenderFlow()
r.CancelFlowLogs()
routerControl.Start()
senderControl.Start()
// First establish the tunnel with normal IPv6 traffic so handshake completes
assertTunnel(t, routerVpnIpNet[0].Addr(), senderControl.GetVpnAddrs()[0], routerControl, senderControl, r)
// Inject the unsolicited reply from router's TUN to the SNAT address.
// There is NO prior outbound SNAT flow, so no conntrack entry exists.
// The router should silently drop this because unSnat finds no matching conntrack.
routerControl.InjectTunUDPPacket(snatAddr, 55555, netip.MustParseAddr("192.168.1.1"), 80, []byte("unsolicited"))
// Send a canary IPv6 VPN packet after the bad one. Since the router processes
// TUN packets sequentially, the canary arriving proves the bad packet was processed first.
senderVpnAddr := senderControl.GetVpnAddrs()[0]
routerControl.InjectTunUDPPacket(senderVpnAddr, 90, routerVpnIpNet[0].Addr(), 80, []byte("canary"))
canaryPkt := r.RouteForAllUntilTxTun(senderControl)
assertUdpPacket(t, []byte("canary"), canaryPkt, routerVpnIpNet[0].Addr(), senderVpnAddr, 80, 90)
// The unsolicited packet should have been dropped — nothing else on sender's TUN
got := senderControl.GetFromTun(false)
assert.Nil(t, got, "sender should not receive unsolicited packet to SNAT address with no conntrack entry")
routerControl.Stop()
senderControl.Stop()
}
func TestSNAT_NonUnsafeDestDropped(t *testing.T) {
// An IPv6-only sender sends IPv4 traffic to a destination outside the router's
// unsafe networks (172.16.0.1 when unsafe is 192.168.0.0/16). The router should
// reject this because the local address is not routable. This verifies that
// willingToHandleLocalAddr enforces boundaries on what SNAT traffic can reach.
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
unsafePrefix := "192.168.0.0/16"
snatAddr := netip.MustParseAddr("169.254.42.42")
routerControl, routerVpnIpNet, routerUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks(
cert.Version2, ca, caKey, "router", "ff::1/64",
netip.MustParseAddrPort("[beef::1]:4242"),
unsafePrefix,
m{
"firewall": m{
"inbound": []m{{
"proto": "any",
"port": "any",
"host": "any",
"local_cidr": "any",
}},
},
"tun": m{
"snat_address_for_4over6": snatAddr.String(),
},
},
)
// Sender has unsafe routes for BOTH 192.168.0.0/16 AND 172.16.0.0/12 via router.
// This means the sender will route 172.16.0.1 through the tunnel to the router.
// But the router should reject it because 172.16.0.0/12 is NOT in its unsafe networks.
senderControl, _, _, _ := newSimpleServerWithUdp(
cert.Version2, ca, caKey, "sender", "ff::2/64",
netip.MustParseAddrPort("[beef::2]:4242"),
m{
"tun": m{
"unsafe_routes": []m{
{"route": unsafePrefix, "via": routerVpnIpNet[0].Addr().String()},
{"route": "172.16.0.0/12", "via": routerVpnIpNet[0].Addr().String()},
},
},
},
)
senderControl.InjectLightHouseAddr(routerVpnIpNet[0].Addr(), routerUdpAddr)
r := router.NewR(t, routerControl, senderControl)
defer r.RenderFlow()
r.CancelFlowLogs()
routerControl.Start()
senderControl.Start()
// Establish the tunnel first
assertTunnel(t, routerVpnIpNet[0].Addr(), senderControl.GetVpnAddrs()[0], routerControl, senderControl, r)
// Send to 172.16.0.1 (NOT in router's unsafe networks 192.168.0.0/16).
// The router should reject this at willingToHandleLocalAddr.
senderControl.InjectTunUDPPacket(
netip.MustParseAddr("172.16.0.1"), 80,
netip.MustParseAddr("10.0.0.1"), 12345,
[]byte("wrong dest"),
)
// Send a canary to a valid unsafe destination to prove the bad packet was processed
senderControl.InjectTunUDPPacket(
netip.MustParseAddr("192.168.1.1"), 80,
netip.MustParseAddr("10.0.0.1"), 33333,
[]byte("canary"),
)
// Route until the canary arrives — the 172.16.0.1 packet should have been
// processed and dropped before the canary gets through
canaryPkt := r.RouteForAllUntilTxTun(routerControl)
_, canaryDst, _, _, canaryPayload := parseIPv4UDPPacket(t, canaryPkt)
assert.Equal(t, netip.MustParseAddr("192.168.1.1"), canaryDst, "canary should arrive at the valid unsafe dest")
assert.Equal(t, []byte("canary"), canaryPayload)
// No more packets — the 172.16.0.1 packet was dropped
got := routerControl.GetFromTun(false)
assert.Nil(t, got, "packet to non-unsafe destination 172.16.0.1 should be dropped by the router")
routerControl.Stop()
senderControl.Stop()
}

View file

@ -2,6 +2,7 @@ package nebula
import (
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
@ -22,10 +23,35 @@ import (
"github.com/slackhq/nebula/firewall"
)
var ErrCannotSNAT = errors.New("cannot SNAT this packet")
var ErrSNATIdentityMismatch = errors.New("refusing to SNAT for mismatched host")
var ErrSNATAddressCollision = errors.New("refusing to accept an incoming packet with my SNAT address")
const ipv4SourcePosition = 12
const ipv4DestinationPosition = 16
const sourcePortOffset = 0
const destinationPortOffset = 2
type FirewallInterface interface {
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr string, caName string, caSha string) error
}
type snatInfo struct {
//Src is the source IP+port to write into unsafe-route-bound packet
Src netip.AddrPort
//SrcVpnIp is the overlay IP associated with this flow. It's needed to associate reply traffic so we can get it back to the right host.
SrcVpnIp netip.Addr
//SnatPort is the port to rewrite into an overlay-bound packet
SnatPort uint16
}
func (s *snatInfo) Valid() bool {
if s == nil {
return false
}
return s.Src.IsValid()
}
type conn struct {
Expires time.Time // Time when this conntrack entry will expire
@ -34,6 +60,9 @@ type conn struct {
// fields pack for free after the uint32 above
incoming bool
rulesVersion uint16
//for SNAT support
snat *snatInfo
}
// TODO: need conntrack max tracked connections handling
@ -66,6 +95,8 @@ type Firewall struct {
defaultLocalCIDRAny bool
incomingMetrics firewallMetrics
outgoingMetrics firewallMetrics
unsafeIPv4Origin netip.Addr
snatAddr netip.Addr
l *logrus.Logger
}
@ -193,12 +224,12 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) {
certificate := cs.getCertificate(cert.Version2)
if certificate == nil {
if certificate == nil { //todo if config.initiating_version is set to 1, and unsafe_networks differ, things will suck
certificate = cs.getCertificate(cert.Version1)
}
if certificate == nil {
panic("No certificate available to reconfigure the firewall")
return nil, errors.New("no certificate available to reconfigure the firewall")
}
fw := NewFirewall(
@ -207,7 +238,6 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
certificate,
//TODO: max_connections
)
fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false)
@ -314,6 +344,18 @@ func (f *Firewall) GetRuleHashes() string {
return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10)
}
func (f *Firewall) SetSNATAddressFromInterface(i *Interface) {
//address-mutation-avoidance is done inside Interface, the firewall doesn't need to care
//todo should snatted conntracks get expired out? Probably not needed until if/when we allow reload
f.snatAddr = i.inside.SNATAddress().Addr()
f.unsafeIPv4Origin = i.inside.UnsafeIPv4OriginAddress().Addr()
}
func (f *Firewall) ShouldUnSNAT(fp *firewall.Packet) bool {
// f.snatAddr is only valid if we're a snat-capable router
return f.snatAddr.IsValid() && fp.RemoteAddr == f.snatAddr
}
func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error {
var table string
if inbound {
@ -414,50 +456,207 @@ var ErrInvalidRemoteIP = errors.New("remote address is not in remote certificate
var ErrInvalidLocalIP = errors.New("local address is not in list of handled local addresses")
var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
func (f *Firewall) unSnat(data []byte, fp *firewall.Packet) netip.Addr {
c := f.peek(*fp) //unfortunately this needs to lock. Surely there's a better way.
if c == nil {
return netip.Addr{}
}
if !c.snat.Valid() {
return netip.Addr{}
}
oldIP := netip.AddrPortFrom(f.snatAddr, fp.RemotePort)
rewritePacket(data, fp, oldIP, c.snat.Src, ipv4DestinationPosition, destinationPortOffset)
return c.snat.SrcVpnIp
}
func rewritePacket(data []byte, fp *firewall.Packet, oldIP netip.AddrPort, newIP netip.AddrPort, ipOffset int, portOffset int) {
//change address
copy(data[ipOffset:], newIP.Addr().AsSlice())
recalcIPv4Checksum(data, oldIP.Addr(), newIP.Addr())
ipHeaderLen := int(data[0]&0x0F) * 4
switch fp.Protocol {
case firewall.ProtoICMP:
binary.BigEndian.PutUint16(data[ipHeaderLen+4:ipHeaderLen+6], newIP.Port()) //we use the ID field as a "port" for ICMP
icmpCode := uint16(data[ipHeaderLen+1]) //todo not snatting on code yet (but Linux would)
recalcICMPv4Checksum(data, icmpCode, icmpCode, oldIP.Port(), newIP.Port())
case firewall.ProtoUDP:
dstport := ipHeaderLen + portOffset
binary.BigEndian.PutUint16(data[dstport:dstport+2], newIP.Port())
recalcUDPv4Checksum(data, oldIP, newIP)
case firewall.ProtoTCP:
dstport := ipHeaderLen + portOffset
binary.BigEndian.PutUint16(data[dstport:dstport+2], newIP.Port())
recalcTCPv4Checksum(data, oldIP, newIP)
}
}
func (f *Firewall) findUsableSNATPort(fp *firewall.Packet, c *conn) error {
const halfThePorts = 0x7fff
oldPort := fp.RemotePort
conntrack := f.Conntrack
conntrack.Lock()
defer conntrack.Unlock()
for numPortsChecked := 0; numPortsChecked < halfThePorts; numPortsChecked++ {
_, ok := conntrack.Conns[*fp]
if !ok {
//yay, we can use this port
//track the snatted flow with the same expiration as the unsnatted version
c.snat.SnatPort = fp.RemotePort
conntrack.Conns[*fp] = c
return nil
}
//increment and retry. There's probably better strategies out there
fp.RemotePort++
if fp.RemotePort < halfThePorts {
fp.RemotePort += halfThePorts // keep it ephemeral for now
}
}
//if we made it here, we failed
fp.RemotePort = oldPort
return ErrCannotSNAT
}
func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo *HostInfo) error {
if !f.snatAddr.IsValid() {
return ErrCannotSNAT
}
if f.snatAddr == fp.LocalAddr { //a packet that came from UDP (incoming) should never ever have our snat address on it
return ErrSNATAddressCollision
}
if c.snat.Valid() {
//old flow: make sure it came from the right place
if !slices.Contains(hostinfo.vpnAddrs, c.snat.SrcVpnIp) {
return ErrSNATIdentityMismatch
}
fp.RemoteAddr = f.snatAddr
fp.RemotePort = c.snat.SnatPort
} else if hostinfo.vpnAddrs[0].Is6() {
//we got a new flow
c.snat = &snatInfo{
Src: netip.AddrPortFrom(fp.RemoteAddr, fp.RemotePort),
SrcVpnIp: hostinfo.vpnAddrs[0],
}
fp.RemoteAddr = f.snatAddr
//find a new port to use, if needed
err := f.findUsableSNATPort(fp, c)
if err != nil {
c.snat = nil
return err
}
} else {
return ErrCannotSNAT
}
newIP := netip.AddrPortFrom(f.snatAddr, c.snat.SnatPort)
rewritePacket(data, fp, c.snat.Src, newIP, ipv4SourcePosition, sourcePortOffset)
return nil
}
func (f *Firewall) identifyRemoteNetworkType(h *HostInfo, fp firewall.Packet) NetworkType {
if h.networks == nil {
// Simple case: Certificate has one address and no unsafe networks
if h.vpnAddrs[0] == fp.RemoteAddr {
return NetworkTypeVPN
} //else, fallthrough
} else if nwType, ok := h.networks.Lookup(fp.RemoteAddr); ok {
return nwType //will return NetworkTypeVPN or NetworkTypeUnsafe
}
//RemoteAddr not in our networks table
if f.snatAddr.IsValid() && fp.IsIPv4() && h.HasOnlyV6Addresses() {
return NetworkTypeUnverifiedSNATPeer
} else {
return NetworkTypeInvalidPeer
}
}
func (f *Firewall) allowRemoteNetworkType(nwType NetworkType, fp firewall.Packet) error {
switch nwType {
case NetworkTypeVPN:
return nil
case NetworkTypeInvalidPeer:
return ErrInvalidRemoteIP
case NetworkTypeVPNPeer:
//one day we might need a specialSnatMode case in here to handle routers with v4 addresses when we don't also have a v4 address?
return ErrPeerRejected // reject for now, one day this may have different FW rules
case NetworkTypeUnsafe:
return nil // nothing special, one day this may have different FW rules
case NetworkTypeUnverifiedSNATPeer:
if f.unsafeIPv4Origin.IsValid() && fp.LocalAddr == f.unsafeIPv4Origin {
return nil //the client case
}
if f.snatAddr.IsValid() {
if fp.RemoteAddr == f.snatAddr {
return ErrInvalidRemoteIP //we should never get a packet with our SNAT addr as the destination, or "from" our SNAT addr
}
return nil
} else {
return ErrInvalidRemoteIP
}
default:
return ErrUnknownNetworkType //should never happen
}
}
func (f *Firewall) willingToHandleLocalAddr(incoming bool, fp firewall.Packet, remoteNwType NetworkType) error {
if f.routableNetworks.Contains(fp.LocalAddr) {
return nil //easy, this should handle NetworkTypeVPN in all cases, and NetworkTypeUnsafe on the router side
}
if incoming { //at least for now, reject all traffic other than what we've already decided is locally routable
return ErrInvalidLocalIP
}
//below this line, all traffic is outgoing. Outgoing traffic to NetworkTypeUnsafe is not required to be considered inbound-routable
if remoteNwType == NetworkTypeUnsafe {
return nil
}
return ErrInvalidLocalIP
}
// Drop returns an error if the packet should be dropped, explaining why. It
// returns nil if the packet should not be dropped.
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error {
func (f *Firewall) Drop(fp firewall.Packet, pkt []byte, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error {
table := f.OutRules
if incoming {
table = f.InRules
}
snatmode := fp.IsIPv4() && h.HasOnlyV6Addresses() && f.snatAddr.IsValid()
if snatmode {
//if this is an IPv4 packet from a V6 only host, and we're configured to snat that kind of traffic, it must be snatted,
//so it can never be in the localcache, which lacks SNAT data
//nil out the pointer to avoid ever using it
localCache = nil
}
// Check if we spoke to this tuple, if we did then allow this packet
if f.inConns(fp, h, caPool, localCache) {
if localCache != nil {
if _, ok := localCache[fp]; ok {
return nil //packet matched the cache, we're not snatting, we can return early
}
}
c := f.inConns(fp, h, caPool, localCache)
if c != nil {
if incoming && snatmode {
return f.applySnat(pkt, &fp, c, h)
}
return nil
}
// Make sure remote address matches nebula certificate, and determine how to treat it
if h.networks == nil {
// Simple case: Certificate has one address and no unsafe networks
if h.vpnAddrs[0] != fp.RemoteAddr {
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP
}
} else {
nwType, ok := h.networks.Lookup(fp.RemoteAddr)
if !ok {
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP
}
switch nwType {
case NetworkTypeVPN:
break // nothing special
case NetworkTypeVPNPeer:
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrPeerRejected // reject for now, one day this may have different FW rules
case NetworkTypeUnsafe:
break // nothing special, one day this may have different FW rules
default:
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrUnknownNetworkType //should never happen
}
remoteNetworkType := f.identifyRemoteNetworkType(h, fp)
if err := f.allowRemoteNetworkType(remoteNetworkType, fp); err != nil {
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return err
}
// Make sure we are supposed to be handling this local ip address
if !f.routableNetworks.Contains(fp.LocalAddr) {
if err := f.willingToHandleLocalAddr(incoming, fp, remoteNetworkType); err != nil {
f.metrics(incoming).droppedLocalAddr.Inc(1)
return ErrInvalidLocalIP
}
table := f.OutRules
if incoming {
table = f.InRules
return err
}
// We now know which firewall table to check against
@ -467,9 +666,14 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
}
// We always want to conntrack since it is a faster operation
f.addConn(fp, incoming)
c = f.addConn(fp, incoming)
return nil
if incoming && remoteNetworkType == NetworkTypeUnverifiedSNATPeer {
return f.applySnat(pkt, &fp, c, h)
} else {
//outgoing snat is handled before this function is called
return nil
}
}
func (f *Firewall) metrics(incoming bool) firewallMetrics {
@ -496,12 +700,14 @@ func (f *Firewall) EmitStats() {
metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
}
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool {
if localCache != nil {
if _, ok := localCache[fp]; ok {
return true
}
}
func (f *Firewall) peek(fp firewall.Packet) *conn {
f.Conntrack.Lock()
c := f.Conntrack.Conns[fp]
f.Conntrack.Unlock()
return c
}
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) *conn {
conntrack := f.Conntrack
conntrack.Lock()
@ -515,7 +721,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
if !ok {
conntrack.Unlock()
return false
return nil
}
if c.rulesVersion != f.rulesVersion {
@ -538,7 +744,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
}
delete(conntrack.Conns, fp)
conntrack.Unlock()
return false
return nil
}
if f.l.Level >= logrus.DebugLevel {
@ -568,12 +774,11 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
localCache[fp] = struct{}{}
}
return true
return c
}
func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
func (f *Firewall) packetTimeout(fp firewall.Packet) time.Duration {
var timeout time.Duration
c := &conn{}
switch fp.Protocol {
case firewall.ProtoTCP:
@ -583,7 +788,13 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
default:
timeout = f.DefaultTimeout
}
return timeout
}
func (f *Firewall) addConn(fp firewall.Packet, incoming bool) *conn {
c := &conn{}
timeout := f.packetTimeout(fp)
conntrack := f.Conntrack
conntrack.Lock()
if _, ok := conntrack.Conns[fp]; !ok {
@ -597,7 +808,9 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
c.rulesVersion = f.rulesVersion
c.Expires = time.Now().Add(timeout)
conntrack.Conns[fp] = c
conntrack.Unlock()
return c
}
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel

View file

@ -31,6 +31,10 @@ type Packet struct {
Fragment bool
}
func (fp *Packet) IsIPv4() bool {
return fp.LocalAddr.Is4() && fp.RemoteAddr.Is4()
}
func (fp *Packet) Copy() *Packet {
return &Packet{
LocalAddr: fp.LocalAddr,

View file

@ -213,44 +213,44 @@ func TestFirewall_Drop(t *testing.T) {
cp := cert.NewCAPool()
// Drop outbound
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, nil, false, &h, cp, nil))
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
// Allow outbound because conntrack
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil))
// test remote mismatch
oldRemote := p.RemoteAddr
p.RemoteAddr = netip.MustParseAddr("1.2.3.10")
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrInvalidRemoteIP)
p.RemoteAddr = oldRemote
// ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
}
func TestFirewall_DropV6(t *testing.T) {
@ -292,44 +292,44 @@ func TestFirewall_DropV6(t *testing.T) {
cp := cert.NewCAPool()
// Drop outbound
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, nil, false, &h, cp, nil))
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
// Allow outbound because conntrack
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil))
// test remote mismatch
oldRemote := p.RemoteAddr
p.RemoteAddr = netip.MustParseAddr("fd12::56")
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrInvalidRemoteIP)
p.RemoteAddr = oldRemote
// ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
}
func BenchmarkFirewallTable_match(b *testing.B) {
@ -537,10 +537,10 @@ func TestFirewall_Drop2(t *testing.T) {
cp := cert.NewCAPool()
// h1/c1 lacks the proper groups
require.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule)
require.ErrorIs(t, fw.Drop(p, nil, true, &h1, cp, nil), ErrNoMatchingRule)
// c has the proper groups
resetConntrack(fw)
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
}
func TestFirewall_Drop3(t *testing.T) {
@ -618,18 +618,18 @@ func TestFirewall_Drop3(t *testing.T) {
cp := cert.NewCAPool()
// c1 should pass because host match
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h1, cp, nil))
// c2 should pass because ca sha match
resetConntrack(fw)
require.NoError(t, fw.Drop(p, true, &h2, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h2, cp, nil))
// c3 should fail because no match
resetConntrack(fw)
assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p, nil, true, &h3, cp, nil), ErrNoMatchingRule)
// Test a remote address match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", ""))
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h1, cp, nil))
}
func TestFirewall_Drop3V6(t *testing.T) {
@ -667,7 +667,7 @@ func TestFirewall_Drop3V6(t *testing.T) {
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
cp := cert.NewCAPool()
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
}
func TestFirewall_DropConntrackReload(t *testing.T) {
@ -709,12 +709,12 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
cp := cert.NewCAPool()
// Drop outbound
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
// Allow outbound because conntrack
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil))
oldFw := fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
@ -723,7 +723,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
fw.rulesVersion = oldFw.rulesVersion + 1
// Allow outbound because conntrack and new rules allow port 10
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil))
oldFw = fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
@ -732,7 +732,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
fw.rulesVersion = oldFw.rulesVersion + 1
// Drop outbound because conntrack doesn't match new ruleset
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrNoMatchingRule)
}
func TestFirewall_ICMPPortBehavior(t *testing.T) {
@ -778,12 +778,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
p.LocalPort = 0
p.RemotePort = 0
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil))
//now also allow outbound
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil))
})
t.Run("nonzero ports", func(t *testing.T) {
@ -791,12 +791,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
p.LocalPort = 0xabcd
p.RemotePort = 0x1234
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil))
//now also allow outbound
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil))
})
})
@ -808,12 +808,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
p.LocalPort = 0
p.RemotePort = 0
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule)
//now also allow outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
})
t.Run("nonzero ports, still blocked", func(t *testing.T) {
@ -821,12 +821,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
p.LocalPort = 0xabcd
p.RemotePort = 0x1234
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule)
//now also allow outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
})
t.Run("nonzero, matching ports, still blocked", func(t *testing.T) {
@ -834,12 +834,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
p.LocalPort = 80
p.RemotePort = 80
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule)
//now also allow outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
})
})
t.Run("Any proto, any port", func(t *testing.T) {
@ -851,12 +851,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
p.LocalPort = 0
p.RemotePort = 0
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil))
//now also allow outbound
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil))
})
t.Run("nonzero ports, allowed", func(t *testing.T) {
@ -865,15 +865,15 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
p.LocalPort = 0xabcd
p.RemotePort = 0x1234
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil))
//now also allow outbound
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil))
//different ID is blocked
p.RemotePort++
require.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
require.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
})
})
@ -922,7 +922,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
Protocol: firewall.ProtoUDP,
Fragment: false,
}
assert.Equal(t, fw.Drop(p, true, &h1, cp, nil), ErrInvalidRemoteIP)
assert.Equal(t, fw.Drop(p, nil, true, &h1, cp, nil), ErrInvalidRemoteIP)
}
func BenchmarkLookup(b *testing.B) {
@ -1042,7 +1042,7 @@ func TestNewFirewallFromConfig(t *testing.T) {
l := test.NewLogger()
// Test a bad rule definition
c := &dummyCert{}
cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil)
cs, err := newCertState(l, cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil)
require.NoError(t, err)
conf := config.NewC(l)
@ -1336,7 +1336,7 @@ func (c *testcase) Test(t *testing.T, fw *Firewall) {
t.Helper()
cp := cert.NewCAPool()
resetConntrack(fw)
err := fw.Drop(c.p, true, c.h, cp, nil)
err := fw.Drop(c.p, nil, true, c.h, cp, nil)
if c.err == nil {
require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr)
} else {
@ -1344,7 +1344,7 @@ func (c *testcase) Test(t *testing.T, fw *Firewall) {
}
}
func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) testcase {
func buildHostinfo(setup testsetup, theirPrefixes ...netip.Prefix) *HostInfo {
c1 := dummyCert{
name: "host1",
networks: theirPrefixes,
@ -1364,6 +1364,11 @@ func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) te
h.vpnAddrs[i] = theirPrefixes[i].Addr()
}
h.buildNetworks(setup.myVpnNetworksTable, &c1)
return &h
}
func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) testcase {
h := buildHostinfo(setup, theirPrefixes...)
p := firewall.Packet{
LocalAddr: setup.c.Networks()[0].Addr(), //todo?
RemoteAddr: theirPrefixes[0].Addr(),
@ -1373,9 +1378,9 @@ func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) te
Fragment: false,
}
return testcase{
h: &h,
h: h,
p: p,
c: &c1,
c: h.ConnectionState.peerCert.Certificate,
err: err,
}
}
@ -1397,6 +1402,19 @@ func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testse
return newSetupFromCert(t, l, c)
}
func newSnatSetup(t *testing.T, l *logrus.Logger, myPrefix netip.Prefix, snatAddr netip.Addr) testsetup {
c := dummyCert{
name: "me",
networks: []netip.Prefix{myPrefix},
groups: []string{"default-group"},
issuer: "signer-shasum",
}
out := newSetupFromCert(t, l, c)
out.fw.snatAddr = snatAddr
return out
}
func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
myVpnNetworksTable := new(bart.Lite)
for _, prefix := range c.Networks() {
@ -1532,3 +1550,59 @@ func resetConntrack(fw *Firewall) {
fw.Conntrack.Conns = map[firewall.Packet]*conn{}
fw.Conntrack.Unlock()
}
func TestFirewall_SNAT(t *testing.T) {
t.Parallel()
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
cp := cert.NewCAPool()
myPrefix := netip.MustParsePrefix("1.1.1.1/8")
MyCert := dummyCert{
name: "me",
networks: []netip.Prefix{myPrefix},
groups: []string{"default-group"},
issuer: "signer-shasum",
}
theirPrefix := netip.MustParsePrefix("1.2.2.2/8")
snatAddr := netip.MustParseAddr("169.254.55.96")
t.Run("allow inbound all matching", func(t *testing.T) {
t.Parallel()
myCert := MyCert.Copy()
setup := newSnatSetup(t, l, myPrefix, snatAddr)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
resetConntrack(setup.fw)
h := buildHostinfo(setup, theirPrefix)
p := firewall.Packet{
LocalAddr: setup.c.Networks()[0].Addr(), //todo?
RemoteAddr: h.vpnAddrs[0],
LocalPort: 10,
RemotePort: 90,
Protocol: firewall.ProtoUDP,
Fragment: false,
}
require.NoError(t, setup.fw.Drop(p, nil, true, h, cp, nil))
})
//t.Run("allow inbound unsafe route", func(t *testing.T) {
// t.Parallel()
// unsafePrefix := netip.MustParsePrefix("192.168.0.0/24")
// c := dummyCert{
// name: "me",
// networks: []netip.Prefix{myPrefix},
// unsafeNetworks: []netip.Prefix{unsafePrefix},
// groups: []string{"default-group"},
// issuer: "signer-shasum",
// }
// unsafeSetup := newSetupFromCert(t, l, c)
// tc := buildTestCase(unsafeSetup, nil, twoPrefixes...)
// tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3")
// tc.err = ErrNoMatchingRule
// tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off
// require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", unsafePrefix.String(), "", ""))
// tc.err = nil
// tc.Test(t, unsafeSetup.fw) //should pass
//})
}

View file

@ -224,6 +224,9 @@ const (
NetworkTypeVPNPeer
// NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks()
NetworkTypeUnsafe
// NetworkTypeUnverifiedSNATPeer is used to indicate traffic we're willing to route, but never deliver to a NetworkTypeVPN
NetworkTypeUnverifiedSNATPeer
NetworkTypeInvalidPeer
)
type HostInfo struct {
@ -277,6 +280,15 @@ type HostInfo struct {
lastUsed time.Time
}
func (i *HostInfo) HasOnlyV6Addresses() bool {
for _, vpnIp := range i.vpnAddrs {
if !vpnIp.Is6() {
return false
}
}
return true
}
type ViaSender struct {
UdpAddr netip.AddrPort
relayHI *HostInfo // relayHI is the host info object of the relay

View file

@ -48,9 +48,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
return
}
hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
})
hostinfo, ready := f.getHostinfo(packet, fwPacket)
if hostinfo == nil {
f.rejectInside(packet, out, q)
@ -66,10 +64,9 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
return
}
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
dropReason := f.firewall.Drop(*fwPacket, packet, false, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason == nil {
f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
} else {
f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel {
@ -81,6 +78,26 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
}
}
func (f *Interface) getHostinfo(packet []byte, fwPacket *firewall.Packet) (*HostInfo, bool) {
if f.firewall.ShouldUnSNAT(fwPacket) {
//unsnat packet re-writing also happens here, would be nice to not,
//but we need to do the unsnat lookup to find the hostinfo so we can run the firewall checks
destVpnAddr := f.firewall.unSnat(packet, fwPacket)
if destVpnAddr.IsValid() {
//because this was a snatted packet, we know it has an on-overlay destination, so no routing should be required.
return f.getOrHandshakeNoRouting(destVpnAddr, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
})
} else {
return nil, false
}
} else { //if we didn't need to unsnat
return f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
})
}
}
func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
if !f.firewall.InSendReject {
return
@ -218,7 +235,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
}
// check if packet is in outbound fw rules
dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil)
dropReason := f.firewall.Drop(*fp, p, false, hostinfo, f.pki.GetCAPool(), nil)
if dropReason != nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("fwPacket", fp).

View file

@ -248,6 +248,7 @@ func (f *Interface) activate() {
f.inside.Close()
f.l.Fatal(err)
}
f.firewall.SetSNATAddressFromInterface(f)
}
func (f *Interface) run() {
@ -344,6 +345,7 @@ func (f *Interface) reloadFirewall(c *config.C) {
f.l.WithError(err).Error("Error while creating firewall during reload")
return
}
fw.SetSNATAddressFromInterface(f)
oldFw := f.firewall
conntrack := oldFw.Conntrack

View file

@ -131,7 +131,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
deviceFactory = overlay.NewDeviceFromConfig
}
tun, err = deviceFactory(c, l, pki.getCertState().myVpnNetworks, routines)
cs := pki.getCertState()
tun, err = deviceFactory(c, l, cs.myVpnNetworks, cs.GetDefaultCertificate().UnsafeNetworks(), routines)
if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err)
}

View file

@ -514,7 +514,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
return false
}
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
dropReason := f.firewall.Drop(*fwPacket, out, true, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason != nil {
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
// This gives us a buffer to build the reject packet in

View file

@ -11,6 +11,9 @@ type Device interface {
io.ReadWriteCloser
Activate() error
Networks() []netip.Prefix
UnsafeNetworks() []netip.Prefix
UnsafeIPv4OriginAddress() netip.Prefix
SNATAddress() netip.Prefix
Name() string
RoutesFor(netip.Addr) routing.Gateways
SupportsMultiqueue() bool

View file

@ -1,7 +1,9 @@
package overlay
import (
"crypto/rand"
"fmt"
"io"
"net"
"net/netip"
@ -22,22 +24,22 @@ func (e *NameError) Error() string {
}
// TODO: We may be able to remove routines
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error)
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error) {
switch {
case c.GetBool("tun.disabled", false):
tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
return tun, nil
default:
return newTun(c, l, vpnNetworks, routines > 1)
return newTun(c, l, vpnNetworks, unsafeNetworks, routines > 1)
}
}
func NewFdDeviceFromConfig(fd *int) DeviceFactory {
return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
return newTunFromFd(c, l, *fd, vpnNetworks)
return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error) {
return newTunFromFd(c, l, *fd, vpnNetworks, unsafeNetworks)
}
}
@ -129,3 +131,85 @@ func selectGateway(dest netip.Prefix, gateways []netip.Prefix) (netip.Prefix, er
return netip.Prefix{}, fmt.Errorf("no gateway found for %v in the list of vpn networks", dest)
}
// genLinkLocal generates a random IPv4 link-local address.
// If randomizer is nil, it uses rand.Reader to find two random bytes
func genLinkLocal(randomizer io.Reader) netip.Prefix {
if randomizer == nil {
randomizer = rand.Reader
}
octets := []byte{169, 254, 0, 0}
_, _ = randomizer.Read(octets[2:4])
return coerceLinkLocal(octets)
}
func coerceLinkLocal(octets []byte) netip.Prefix {
if octets[3] == 0 {
octets[3] = 1 //please no .0 addresses
} else if octets[2] == 255 && octets[3] == 255 {
octets[3] = 254 //please no broadcast addresses
}
out, _ := netip.AddrFromSlice(octets)
return netip.PrefixFrom(out, 32)
}
// prepareUnsafeOriginAddr provides the IPv4 address used on IPv6-only clients that need to access IPv4 unsafe routes
func prepareUnsafeOriginAddr(d Device, l *logrus.Logger, c *config.C, routes []Route) netip.Prefix {
if !d.Networks()[0].Addr().Is6() {
return netip.Prefix{} //if we have an IPv4 assignment within the overlay, we don't need an unsafe origin address
}
needed := false
for _, route := range routes { //or if we have a route defined into an IPv4 range
if route.Cidr.Addr().Is4() {
needed = true //todo should this only apply to unsafe routes? almost certainly
break
}
}
if !needed {
return netip.Prefix{}
}
//todo better config name for sure
if a := c.GetString("tun.unsafe_origin_address_for_4over6", ""); a != "" {
out, err := netip.ParseAddr(a)
if err != nil {
l.WithField("value", a).WithError(err).Warn("failed to parse tun.unsafe_origin_address_for_4over6, will use a random value")
} else if !out.Is4() || !out.IsLinkLocalUnicast() {
l.WithField("value", out).Warn("tun.unsafe_origin_address_for_4over6 must be an IPv4 address")
} else if out.IsValid() {
return netip.PrefixFrom(out, 32)
}
}
return genLinkLocal(nil)
}
// prepareSnatAddr provides the address that an IPv6-only unsafe router should use to SNAT traffic before handing it to the operating system
func prepareSnatAddr(d Device, l *logrus.Logger, c *config.C) netip.Prefix {
if !d.Networks()[0].Addr().Is6() {
return netip.Prefix{} //if we have an IPv4 assignment within the overlay, we don't need a snat address
}
needed := false
for _, un := range d.UnsafeNetworks() { //if we are an unsafe router for an IPv4 range
if un.Addr().Is4() {
needed = true
break
}
}
if !needed {
return netip.Prefix{}
}
if a := c.GetString("tun.snat_address_for_4over6", ""); a != "" {
out, err := netip.ParseAddr(a)
if err != nil {
l.WithField("value", a).WithError(err).Warn("failed to parse tun.snat_address_for_4over6, will use a random value")
} else if !out.Is4() || !out.IsLinkLocalUnicast() {
l.WithField("value", out).Warn("tun.snat_address_for_4over6 must be an IPv4 address")
} else if out.IsValid() {
return netip.PrefixFrom(out, 32)
}
}
return genLinkLocal(nil)
}

View file

@ -19,14 +19,16 @@ import (
type tun struct {
io.ReadWriteCloser
fd int
vpnNetworks []netip.Prefix
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
fd int
vpnNetworks []netip.Prefix
unsafeNetworks []netip.Prefix
unsafeIPv4Origin netip.Prefix
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix) (*tun, error) {
// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
// Be sure not to call file.Fd() as it will set the fd to blocking mode.
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
@ -35,6 +37,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net
ReadWriteCloser: file,
fd: deviceFd,
vpnNetworks: vpnNetworks,
unsafeNetworks: unsafeNetworks,
l: l,
}
@ -53,7 +56,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net
return t, nil
}
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ []netip.Prefix, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in Android")
}
@ -76,6 +79,8 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil
}
t.unsafeIPv4Origin = prepareUnsafeOriginAddr(t, t.l, c, routes)
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
@ -91,6 +96,18 @@ func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (t *tun) UnsafeNetworks() []netip.Prefix {
return t.unsafeNetworks
}
func (t *tun) UnsafeIPv4OriginAddress() netip.Prefix {
return t.unsafeIPv4Origin
}
func (t *tun) SNATAddress() netip.Prefix {
return netip.Prefix{}
}
func (t *tun) Name() string {
return "android"
}

View file

@ -24,13 +24,15 @@ import (
type tun struct {
io.ReadWriteCloser
Device string
vpnNetworks []netip.Prefix
DefaultMTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
linkAddr *netroute.LinkAddr
l *logrus.Logger
Device string
vpnNetworks []netip.Prefix
unsafeNetworks []netip.Prefix
unsafeIPv4Origin netip.Prefix
DefaultMTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
linkAddr *netroute.LinkAddr
l *logrus.Logger
// cache out buffer since we need to prepend 4 bytes for tun metadata
out []byte
@ -79,7 +81,7 @@ type ifreqAlias6 struct {
Lifetime addrLifetime
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) {
name := c.GetString("tun.dev", "")
ifIndex := -1
if name != "" && name != "utun" {
@ -127,6 +129,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
ReadWriteCloser: os.NewFile(uintptr(fd), ""),
Device: name,
vpnNetworks: vpnNetworks,
unsafeNetworks: unsafeNetworks,
DefaultMTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}
@ -153,7 +156,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
return
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
}
@ -213,6 +216,11 @@ func (t *tun) Activate() error {
}
}
}
if t.unsafeIPv4Origin.IsValid() && t.unsafeIPv4Origin.Addr().Is4() {
if err = t.activate4(t.unsafeIPv4Origin); err != nil {
return err
}
}
// Run the interface
ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
@ -314,6 +322,10 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil
}
if initial {
t.unsafeIPv4Origin = prepareUnsafeOriginAddr(t, t.l, c, routes)
}
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
@ -545,6 +557,18 @@ func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (t *tun) UnsafeNetworks() []netip.Prefix {
return t.unsafeNetworks
}
func (t *tun) UnsafeIPv4OriginAddress() netip.Prefix {
return t.unsafeIPv4Origin
}
func (t *tun) SNATAddress() netip.Prefix {
return netip.Prefix{}
}
func (t *tun) Name() string {
return t.Device
}

View file

@ -52,6 +52,17 @@ func (t *disabledTun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (*disabledTun) UnsafeNetworks() []netip.Prefix {
return nil
}
func (*disabledTun) SNATAddress() netip.Prefix {
return netip.Prefix{}
}
func (*disabledTun) UnsafeIPv4OriginAddress() netip.Prefix {
return netip.Prefix{}
}
func (*disabledTun) Name() string {
return "disabled"
}

View file

@ -86,14 +86,16 @@ type ifreqAlias6 struct {
}
type tun struct {
Device string
vpnNetworks []netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
linkAddr *netroute.LinkAddr
l *logrus.Logger
devFd int
Device string
vpnNetworks []netip.Prefix
unsafeNetworks []netip.Prefix
unsafeIPv4Origin netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
linkAddr *netroute.LinkAddr
l *logrus.Logger
devFd int
}
func (t *tun) Read(to []byte) (int, error) {
@ -199,11 +201,11 @@ func (t *tun) Close() error {
return nil
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open existing tun device
var fd int
var err error
@ -270,11 +272,12 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
}
t := &tun{
Device: deviceName,
vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
devFd: fd,
Device: deviceName,
vpnNetworks: vpnNetworks,
unsafeNetworks: unsafeNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
devFd: fd,
}
err = t.reload(c, true)
@ -410,6 +413,10 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil
}
if initial {
t.unsafeIPv4Origin = prepareUnsafeOriginAddr(t, t.l, c, routes)
}
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
@ -446,6 +453,18 @@ func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (t *tun) UnsafeNetworks() []netip.Prefix {
return t.unsafeNetworks
}
func (t *tun) UnsafeIPv4OriginAddress() netip.Prefix {
return t.unsafeIPv4Origin
}
func (t *tun) SNATAddress() netip.Prefix {
return netip.Prefix{}
}
func (t *tun) Name() string {
return t.Device
}

View file

@ -22,20 +22,23 @@ import (
type tun struct {
io.ReadWriteCloser
vpnNetworks []netip.Prefix
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
vpnNetworks []netip.Prefix
unsafeNetworks []netip.Prefix
unsafeIPv4Origin netip.Prefix
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
}
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ []netip.Prefix, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in iOS")
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix) (*tun, error) {
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
t := &tun{
vpnNetworks: vpnNetworks,
unsafeNetworks: unsafeNetworks,
ReadWriteCloser: &tunReadCloser{f: file},
l: l,
}
@ -69,6 +72,8 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil
}
t.unsafeIPv4Origin = prepareUnsafeOriginAddr(t, t.l, c, routes)
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
@ -147,6 +152,18 @@ func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (t *tun) UnsafeNetworks() []netip.Prefix {
return t.unsafeNetworks
}
func (t *tun) UnsafeIPv4OriginAddress() netip.Prefix {
return t.unsafeIPv4Origin
}
func (t *tun) SNATAddress() netip.Prefix {
return netip.Prefix{}
}
func (t *tun) Name() string {
return "iOS"
}

View file

@ -26,14 +26,15 @@ import (
type tun struct {
io.ReadWriteCloser
fd int
Device string
vpnNetworks []netip.Prefix
MaxMTU int
DefaultMTU int
TXQueueLen int
deviceIndex int
ioctlFd uintptr
fd int
Device string
vpnNetworks []netip.Prefix
unsafeNetworks []netip.Prefix
MaxMTU int
DefaultMTU int
TXQueueLen int
deviceIndex int
ioctlFd uintptr
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
@ -46,6 +47,9 @@ type tun struct {
routesFromSystem map[netip.Prefix]routing.Gateways
routesFromSystemLock sync.Mutex
snatAddr netip.Prefix
unsafeIPv4Origin netip.Prefix
l *logrus.Logger
}
@ -53,6 +57,18 @@ func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (t *tun) UnsafeNetworks() []netip.Prefix {
return t.unsafeNetworks
}
func (t *tun) UnsafeIPv4OriginAddress() netip.Prefix {
return t.unsafeIPv4Origin
}
func (t *tun) SNATAddress() netip.Prefix {
return t.snatAddr
}
type ifReq struct {
Name [16]byte
Flags uint16
@ -71,10 +87,10 @@ type ifreqQLEN struct {
pad [8]byte
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix) (*tun, error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
t, err := newTunGeneric(c, l, file, vpnNetworks)
t, err := newTunGeneric(c, l, file, vpnNetworks, unsafeNetworks)
if err != nil {
return nil, err
}
@ -84,7 +100,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net
return t, nil
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
@ -123,7 +139,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
name := strings.Trim(string(req.Name[:]), "\x00")
file := os.NewFile(uintptr(fd), "/dev/net/tun")
t, err := newTunGeneric(c, l, file, vpnNetworks)
t, err := newTunGeneric(c, l, file, vpnNetworks, unsafeNetworks)
if err != nil {
return nil, err
}
@ -133,11 +149,12 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
return t, nil
}
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix) (*tun, error) {
t := &tun{
ReadWriteCloser: file,
fd: int(file.Fd()),
vpnNetworks: vpnNetworks,
unsafeNetworks: unsafeNetworks,
TXQueueLen: c.GetInt("tun.tx_queue", 500),
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
@ -170,6 +187,11 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil
}
if initial {
t.unsafeIPv4Origin = prepareUnsafeOriginAddr(t, t.l, c, routes) //todo MUST be different from t.snatAddr!
t.snatAddr = prepareSnatAddr(t, t.l, c)
}
routeTree, err := makeRouteTree(t.l, routes, true)
if err != nil {
return err
@ -313,6 +335,17 @@ func (t *tun) addIPs(link netlink.Link) error {
}
}
if t.unsafeIPv4Origin.IsValid() {
newAddrs = append(newAddrs, &netlink.Addr{
IPNet: &net.IPNet{
IP: t.unsafeIPv4Origin.Addr().AsSlice(),
Mask: net.CIDRMask(t.unsafeIPv4Origin.Bits(), t.unsafeIPv4Origin.Addr().BitLen()),
},
Label: t.unsafeIPv4Origin.Addr().Zone(),
})
t.l.WithField("address", t.unsafeIPv4Origin).Info("Adding origin address for IPv4 unsafe_routes")
}
//add all new addresses
for i := range newAddrs {
//AddrReplace still adds new IPs, but if their properties change it will change them as well
@ -400,7 +433,13 @@ func (t *tun) Activate() error {
//set route MTU
for i := range t.vpnNetworks {
if err = t.setDefaultRoute(t.vpnNetworks[i]); err != nil {
return fmt.Errorf("failed to set default route MTU: %w", err)
return fmt.Errorf("failed to set default route MTU for %s: %w", t.vpnNetworks[i], err)
}
}
if t.unsafeIPv4Origin.IsValid() {
if err = t.setDefaultRoute(t.unsafeIPv4Origin); err != nil {
return fmt.Errorf("failed to set default route MTU for %s: %w", t.unsafeIPv4Origin, err)
}
}
@ -427,6 +466,23 @@ func (t *tun) setMTU() {
}
}
func (t *tun) setSnatRoute() error {
dr := &net.IPNet{
IP: t.snatAddr.Masked().Addr().AsSlice(),
Mask: net.CIDRMask(t.snatAddr.Bits(), t.snatAddr.Addr().BitLen()),
}
nr := netlink.Route{
LinkIndex: t.deviceIndex,
Dst: dr,
Scope: unix.RT_SCOPE_LINK,
//Protocol: unix.RTPROT_KERNEL,
Table: unix.RT_TABLE_MAIN,
Type: unix.RTN_UNICAST,
}
return netlink.RouteReplace(&nr)
}
func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
dr := &net.IPNet{
IP: cidr.Masked().Addr().AsSlice(),
@ -503,6 +559,13 @@ func (t *tun) addRoutes(logErrors bool) error {
}
}
if t.snatAddr.IsValid() {
//at least for Linux, we need to set a return route for the SNATted traffic in order to satisfy the reverse-path filter,
//and to help the kernel deliver our reply traffic to the tun device.
//however, it is important that we do not actually /assign/ the SNAT address,
//since link-local addresses will not be routed between interfaces without significant trickery.
return t.setSnatRoute()
}
return nil
}

View file

@ -58,23 +58,25 @@ type addrLifetime struct {
}
type tun struct {
Device string
vpnNetworks []netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
f *os.File
fd int
Device string
vpnNetworks []netip.Prefix
unsafeNetworks []netip.Prefix
unsafeIPv4Origin netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
f *os.File
fd int
}
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open tun device
var err error
deviceName := c.GetString("tun.dev", "")
@ -350,6 +352,10 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil
}
if initial {
t.unsafeIPv4Origin = prepareUnsafeOriginAddr(t, t.l, c, routes)
}
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
@ -386,6 +392,18 @@ func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (t *tun) UnsafeNetworks() []netip.Prefix {
return t.unsafeNetworks
}
func (t *tun) UnsafeIPv4OriginAddress() netip.Prefix {
return t.unsafeIPv4Origin
}
func (t *tun) SNATAddress() netip.Prefix {
return netip.Prefix{}
}
func (t *tun) Name() string {
return t.Device
}

View file

@ -49,25 +49,27 @@ type ifreq struct {
}
type tun struct {
Device string
vpnNetworks []netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
f *os.File
fd int
Device string
vpnNetworks []netip.Prefix
unsafeNetworks []netip.Prefix
unsafeIPv4Origin netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
f *os.File
fd int
// cache out buffer since we need to prepend 4 bytes for tun metadata
out []byte
}
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in openbsd")
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open tun device
var err error
deviceName := c.GetString("tun.dev", "")
@ -89,12 +91,13 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
}
t := &tun{
f: os.NewFile(uintptr(fd), ""),
fd: fd,
Device: deviceName,
vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
f: os.NewFile(uintptr(fd), ""),
fd: fd,
Device: deviceName,
vpnNetworks: vpnNetworks,
unsafeNetworks: unsafeNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}
err = t.reload(c, true)
@ -270,6 +273,10 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil
}
if initial {
t.unsafeIPv4Origin = prepareUnsafeOriginAddr(t, t.l, c, routes)
}
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
@ -306,6 +313,18 @@ func (t *tun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (t *tun) UnsafeNetworks() []netip.Prefix {
return t.unsafeNetworks
}
func (t *tun) UnsafeIPv4OriginAddress() netip.Prefix {
return t.unsafeIPv4Origin
}
func (t *tun) SNATAddress() netip.Prefix {
return netip.Prefix{}
}
func (t *tun) Name() string {
return t.Device
}

179
overlay/tun_snat_test.go Normal file
View file

@ -0,0 +1,179 @@
package overlay
import (
"io"
"net/netip"
"testing"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// mockDevice is a minimal Device implementation for testing prepareUnsafeOriginAddr.
type mockDevice struct {
networks []netip.Prefix
unsafeNetworks []netip.Prefix
snatAddr netip.Prefix
unsafeSnatAddr netip.Prefix
}
func (d *mockDevice) Read([]byte) (int, error) { return 0, nil }
func (d *mockDevice) Write([]byte) (int, error) { return 0, nil }
func (d *mockDevice) Close() error { return nil }
func (d *mockDevice) Activate() error { return nil }
func (d *mockDevice) Networks() []netip.Prefix { return d.networks }
func (d *mockDevice) UnsafeNetworks() []netip.Prefix { return d.unsafeNetworks }
func (d *mockDevice) SNATAddress() netip.Prefix { return d.snatAddr }
func (d *mockDevice) UnsafeIPv4OriginAddress() netip.Prefix { return d.unsafeSnatAddr }
func (d *mockDevice) Name() string { return "mock" }
func (d *mockDevice) RoutesFor(netip.Addr) routing.Gateways { return routing.Gateways{} }
func (d *mockDevice) SupportsMultiqueue() bool { return false }
func (d *mockDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, nil }
func TestPrepareSnatAddr_V4Primary_NoSnat(t *testing.T) {
l := logrus.New()
l.SetLevel(logrus.PanicLevel)
c := config.NewC(l)
// If the device has an IPv4 primary address, no SNAT needed
d := &mockDevice{
networks: []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")},
}
result := prepareUnsafeOriginAddr(d, l, c, nil)
assert.Equal(t, netip.Prefix{}, result, "should not assign SNAT addr when device has IPv4 primary")
}
func TestPrepareSnatAddr_V6Primary_NoUnsafeOrRoutes(t *testing.T) {
l := logrus.New()
l.SetLevel(logrus.PanicLevel)
c := config.NewC(l)
// IPv6 primary but no unsafe networks or IPv4 routes
d := &mockDevice{
networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")},
}
result := prepareUnsafeOriginAddr(d, l, c, nil)
assert.Equal(t, netip.Prefix{}, result, "should not assign SNAT addr without IPv4 unsafe networks or routes")
}
func TestPrepareSnatAddr_V6Primary_WithV4Unsafe(t *testing.T) {
l := logrus.New()
l.SetLevel(logrus.PanicLevel)
c := config.NewC(l)
// IPv6 primary with IPv4 unsafe network -> should get SNAT addr
d := &mockDevice{
networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")},
unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
}
result := prepareSnatAddr(d, l, c)
require.True(t, result.IsValid(), "should assign SNAT addr")
assert.True(t, result.Addr().Is4(), "SNAT addr should be IPv4")
assert.True(t, result.Addr().IsLinkLocalUnicast(), "SNAT addr should be link-local")
assert.Equal(t, 32, result.Bits(), "SNAT addr should be /32")
result = prepareUnsafeOriginAddr(d, l, c, nil)
require.False(t, result.IsValid(), "no routes = no origin addr needed")
}
func TestPrepareUnsafeOriginAddr_V6Primary_WithV4Route(t *testing.T) {
l := logrus.New()
l.SetLevel(logrus.PanicLevel)
c := config.NewC(l)
// IPv6 primary with IPv4 route -> should get SNAT addr
d := &mockDevice{
networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")},
}
routes := []Route{
{Cidr: netip.MustParsePrefix("10.0.0.0/8")},
}
result := prepareUnsafeOriginAddr(d, l, c, routes)
require.True(t, result.IsValid(), "should assign SNAT addr when IPv4 route exists")
assert.True(t, result.Addr().Is4())
assert.True(t, result.Addr().IsLinkLocalUnicast())
result = prepareSnatAddr(d, l, c)
require.False(t, result.IsValid(), "no UnsafeNetworks = no snat addr needed")
}
func TestPrepareSnatAddr_V6Primary_V6UnsafeOnly(t *testing.T) {
l := logrus.New()
l.SetLevel(logrus.PanicLevel)
c := config.NewC(l)
// IPv6 primary with only IPv6 unsafe network -> no SNAT needed
d := &mockDevice{
networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")},
unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("fd01::/64")},
}
result := prepareUnsafeOriginAddr(d, l, c, nil)
assert.Equal(t, netip.Prefix{}, result, "should not assign SNAT addr for IPv6-only unsafe networks")
}
func TestPrepareSnatAddr_ManualAddress(t *testing.T) {
l := logrus.New()
l.SetLevel(logrus.PanicLevel)
c := config.NewC(l)
c.Settings["tun"] = map[string]any{
"snat_address_for_4over6": "169.254.42.42",
}
d := &mockDevice{
networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")},
unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
}
result := prepareSnatAddr(d, l, c)
require.True(t, result.IsValid())
assert.Equal(t, netip.MustParseAddr("169.254.42.42"), result.Addr())
assert.Equal(t, 32, result.Bits())
}
func TestPrepareSnatAddr_InvalidManualAddress_Fallback(t *testing.T) {
l := logrus.New()
l.SetLevel(logrus.PanicLevel)
c := config.NewC(l)
c.Settings["tun"] = map[string]any{
"snat_address_for_4over6": "not-an-ip",
}
d := &mockDevice{
networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")},
unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
}
result := prepareSnatAddr(d, l, c)
// Should fall back to auto-assignment
require.True(t, result.IsValid(), "should fall back to auto-assigned address")
assert.True(t, result.Addr().Is4())
assert.True(t, result.Addr().IsLinkLocalUnicast())
}
func TestPrepareSnatAddr_AutoGenerated_Range(t *testing.T) {
l := logrus.New()
l.SetLevel(logrus.PanicLevel)
c := config.NewC(l)
d := &mockDevice{
networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")},
unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
}
// Generate several addresses and verify they're all in the expected range
for i := 0; i < 100; i++ {
result := prepareSnatAddr(d, l, c)
require.True(t, result.IsValid())
addr := result.Addr()
octets := addr.As4()
assert.Equal(t, byte(169), octets[0], "first octet should be 169")
assert.Equal(t, byte(254), octets[1], "second octet should be 254")
// Should not have .0 in the last octet
assert.NotEqual(t, byte(0), octets[3], "last octet should not be 0")
// Should not be 169.254.255.255 (broadcast)
if octets[2] == 255 {
assert.NotEqual(t, byte(255), octets[3], "should not be broadcast address")
}
}
}

37
overlay/tun_test.go Normal file
View file

@ -0,0 +1,37 @@
package overlay
import (
"bytes"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
)
func TestLinkLocal(t *testing.T) {
r := bytes.NewReader([]byte{42, 99})
result := genLinkLocal(r)
assert.Equal(t, netip.MustParsePrefix("169.254.42.99/32"), result, "genLinkLocal with a deterministic randomizer")
result = genLinkLocal(nil)
assert.True(t, result.IsValid(), "genLinkLocal with nil randomizer should be valid")
assert.True(t, result.Addr().IsLinkLocalUnicast(), "genLinkLocal with nil randomizer should be link-local")
result = coerceLinkLocal([]byte{169, 254, 100, 50})
assert.Equal(t, netip.MustParsePrefix("169.254.100.50/32"), result, "coerceLinkLocal should pass through normal values")
result = coerceLinkLocal([]byte{169, 254, 0, 0})
assert.Equal(t, netip.MustParsePrefix("169.254.0.1/32"), result, "coerceLinkLocal should bump .0 last octet to .1")
result = coerceLinkLocal([]byte{169, 254, 255, 255})
assert.Equal(t, netip.MustParsePrefix("169.254.255.254/32"), result, "coerceLinkLocal should bump broadcast 255.255 to 255.254")
result = coerceLinkLocal([]byte{169, 254, 0, 1})
assert.Equal(t, netip.MustParsePrefix("169.254.0.1/32"), result, "coerceLinkLocal should leave .1 last octet unchanged")
result = coerceLinkLocal([]byte{169, 254, 255, 254})
assert.Equal(t, netip.MustParsePrefix("169.254.255.254/32"), result, "coerceLinkLocal should leave 255.254 unchanged")
result = coerceLinkLocal([]byte{169, 254, 255, 100})
assert.Equal(t, netip.MustParsePrefix("169.254.255.100/32"), result, "coerceLinkLocal should leave 255.100 unchanged")
}

View file

@ -17,18 +17,21 @@ import (
)
type TestTun struct {
Device string
vpnNetworks []netip.Prefix
Routes []Route
routeTree *bart.Table[routing.Gateways]
l *logrus.Logger
Device string
vpnNetworks []netip.Prefix
unsafeNetworks []netip.Prefix
snatAddr netip.Prefix
unsafeIPv4Origin netip.Prefix
Routes []Route
routeTree *bart.Table[routing.Gateways]
l *logrus.Logger
closed atomic.Bool
rxPackets chan []byte // Packets to receive into nebula
TxPackets chan []byte // Packets transmitted outside by nebula
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) {
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*TestTun, error) {
_, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true)
if err != nil {
return nil, err
@ -38,18 +41,22 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
return nil, err
}
return &TestTun{
Device: c.GetString("tun.dev", ""),
vpnNetworks: vpnNetworks,
Routes: routes,
routeTree: routeTree,
l: l,
rxPackets: make(chan []byte, 10),
TxPackets: make(chan []byte, 10),
}, nil
tt := &TestTun{
Device: c.GetString("tun.dev", ""),
vpnNetworks: vpnNetworks,
unsafeNetworks: unsafeNetworks,
Routes: routes,
routeTree: routeTree,
l: l,
rxPackets: make(chan []byte, 10),
TxPackets: make(chan []byte, 10),
}
tt.unsafeIPv4Origin = prepareUnsafeOriginAddr(tt, l, c, routes)
tt.snatAddr = prepareSnatAddr(tt, tt.l, c)
return tt, nil
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*TestTun, error) {
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*TestTun, error) {
return nil, fmt.Errorf("newTunFromFd not supported")
}
@ -139,3 +146,15 @@ func (t *TestTun) SupportsMultiqueue() bool {
func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented")
}
func (t *TestTun) UnsafeNetworks() []netip.Prefix {
return t.unsafeNetworks
}
func (t *TestTun) UnsafeIPv4OriginAddress() netip.Prefix {
return t.unsafeIPv4Origin
}
func (t *TestTun) SNATAddress() netip.Prefix {
return t.snatAddr
}

View file

@ -28,21 +28,23 @@ import (
const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
type winTun struct {
Device string
vpnNetworks []netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
Device string
vpnNetworks []netip.Prefix
unsafeNetworks []netip.Prefix
unsafeIPv4Origin netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
tun *wintun.NativeTun
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) {
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (Device, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) {
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*winTun, error) {
err := checkWinTunExists()
if err != nil {
return nil, fmt.Errorf("can not load the wintun driver: %w", err)
@ -55,10 +57,11 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
}
t := &winTun{
Device: deviceName,
vpnNetworks: vpnNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
Device: deviceName,
vpnNetworks: vpnNetworks,
unsafeNetworks: unsafeNetworks,
MTU: c.GetInt("tun.mtu", DefaultMTU),
l: l,
}
err = t.reload(c, true)
@ -102,6 +105,10 @@ func (t *winTun) reload(c *config.C, initial bool) error {
return nil
}
if initial {
t.unsafeIPv4Origin = prepareUnsafeOriginAddr(t, t.l, c, routes)
}
routeTree, err := makeRouteTree(t.l, routes, false)
if err != nil {
return err
@ -132,7 +139,12 @@ func (t *winTun) reload(c *config.C, initial bool) error {
func (t *winTun) Activate() error {
luid := winipcfg.LUID(t.tun.LUID())
err := luid.SetIPAddresses(t.vpnNetworks)
prefixes := t.vpnNetworks
if t.unsafeIPv4Origin.IsValid() {
prefixes = append(prefixes, t.unsafeIPv4Origin)
}
err := luid.SetIPAddresses(prefixes)
if err != nil {
return fmt.Errorf("failed to set address: %w", err)
}
@ -225,6 +237,18 @@ func (t *winTun) Networks() []netip.Prefix {
return t.vpnNetworks
}
func (t *winTun) UnsafeNetworks() []netip.Prefix {
return t.unsafeNetworks
}
func (t *winTun) UnsafeIPv4OriginAddress() netip.Prefix {
return t.unsafeIPv4Origin
}
func (t *winTun) SNATAddress() netip.Prefix {
return netip.Prefix{}
}
func (t *winTun) Name() string {
return t.Device
}

View file

@ -9,7 +9,7 @@ import (
"github.com/slackhq/nebula/routing"
)
func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error) {
return NewUserDevice(vpnNetworks)
}
@ -36,6 +36,17 @@ type UserDevice struct {
inboundWriter *io.PipeWriter
}
func (d *UserDevice) UnsafeNetworks() []netip.Prefix {
return nil
}
func (d *UserDevice) SNATAddress() netip.Prefix {
return netip.Prefix{}
}
func (d *UserDevice) UnsafeIPv4OriginAddress() netip.Prefix {
return netip.Prefix{}
}
func (d *UserDevice) Activate() error {
return nil
}

47
pki.go
View file

@ -91,7 +91,7 @@ func (p *PKI) reload(c *config.C, initial bool) error {
}
func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
newState, err := newCertStateFromConfig(c)
newState, err := newCertStateFromConfig(c, p.l)
if err != nil {
return util.NewContextualError("Could not load client cert", nil, err)
}
@ -102,7 +102,7 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
if currentState.v1Cert == nil {
//adding certs is fine, actually. Networks-in-common confirmed in newCertState().
} else {
// did IP in cert change? if so, don't set
// did IP in cert change? if so, don't set. If we ever allow this, need to set p.firewallReloadNeeded
if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
return util.NewContextualError(
"Networks in new cert was different from old",
@ -158,6 +158,14 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
}
}
newUN := newState.GetDefaultCertificate().UnsafeNetworks()
oldUN := currentState.GetDefaultCertificate().UnsafeNetworks()
if !slices.Equal(newUN, oldUN) {
//todo I don't love this, because other clients will see the new assignments and act on them, but we will not be able to.
//I think we need to wire this into the firewall reload.
p.l.WithFields(m{"previous": oldUN, "new": newUN}).Warning("UnsafeNetworks assignments differ. A restart is required in order for this to take effect.")
}
// Cipher cant be hot swapped so just leave it at what it was before
newState.cipher = currentState.cipher
@ -260,7 +268,7 @@ func (cs *CertState) MarshalJSON() ([]byte, error) {
return json.Marshal(msg)
}
func newCertStateFromConfig(c *config.C) (*CertState, error) {
func newCertStateFromConfig(c *config.C, l *logrus.Logger) (*CertState, error) {
var err error
privPathOrPEM := c.GetString("pki.key", "")
@ -344,10 +352,33 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) {
return nil, fmt.Errorf("unknown pki.initiating_version: %v", rawInitiatingVersion)
}
return newCertState(initiatingVersion, v1, v2, isPkcs11, curve, rawKey)
return newCertState(l, initiatingVersion, v1, v2, isPkcs11, curve, rawKey)
}
func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) {
func compareUnsafeNetworksAcrossCertVersions(v1, v2 cert.Certificate) error {
if v1 == nil || v2 == nil {
return nil //can't be a problem if we don't have one of the kinds of cert
}
v4UnsafeNets := 0
for _, n := range v2.UnsafeNetworks() {
if n.Addr().Is6() {
continue // V1 certs can't have IPv6 unsafe networks
} else {
v4UnsafeNets++
}
if !slices.Contains(v1.UnsafeNetworks(), n) {
return errors.New("UnsafeNetworks mismatch")
}
}
if len(v1.UnsafeNetworks()) != v4UnsafeNets {
return errors.New("UnsafeNetworks mismatch")
}
return nil
}
func newCertState(l *logrus.Logger, dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) {
cs := CertState{
privateKey: privateKey,
pkcs11Backed: pkcs11backed,
@ -370,6 +401,12 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
}
cs.initiatingVersion = dv
warn := compareUnsafeNetworksAcrossCertVersions(v1, v2)
if warn != nil {
l.WithFields(m{"UnsafeNetworksV1": v1.UnsafeNetworks(), "UnsafeNetworksV2": v2.UnsafeNetworks()}).
Warning("the IPv4 UnsafeNetworks assigned in the V1 certificate do not match the ones in V2")
}
}
if v1 != nil {

91
snat.go Normal file
View file

@ -0,0 +1,91 @@
package nebula
import (
"encoding/binary"
"net/netip"
)
func recalcIPv4Checksum(data []byte, oldSrcIP netip.Addr, newSrcIP netip.Addr) {
oldChecksum := binary.BigEndian.Uint16(data[10:12])
//because of how checksums work, we can re-use this function
checksum := calcNewTransportChecksum(oldChecksum, oldSrcIP, 0, newSrcIP, 0)
binary.BigEndian.PutUint16(data[10:12], checksum)
}
func calcNewTransportChecksum(oldChecksum uint16, oldSrcIP netip.Addr, oldSrcPort uint16, newSrcIP netip.Addr, newSrcPort uint16) uint16 {
oldIP := binary.BigEndian.Uint32(oldSrcIP.AsSlice())
newIP := binary.BigEndian.Uint32(newSrcIP.AsSlice())
// Start with inverted checksum
checksum := uint32(^oldChecksum)
// Subtract old IP (as two 16-bit words)
checksum += uint32(^uint16(oldIP >> 16))
checksum += uint32(^uint16(oldIP & 0xFFFF))
// Subtract old port
checksum += uint32(^oldSrcPort)
// Add new IP (as two 16-bit words)
checksum += uint32(newIP >> 16)
checksum += uint32(newIP & 0xFFFF)
// Add new port
checksum += uint32(newSrcPort)
// Fold carries
for checksum > 0xFFFF {
checksum = (checksum & 0xFFFF) + (checksum >> 16)
}
// Return ones' complement
return ^uint16(checksum)
}
func recalcV4TransportChecksum(offsetInsideHeader int, data []byte, oldSrcIP netip.AddrPort, newSrcIP netip.AddrPort) {
ipHeaderOffset := int(data[0]&0x0F) * 4
offset := ipHeaderOffset + offsetInsideHeader
oldcsum := binary.BigEndian.Uint16(data[offset : offset+2])
checksum := calcNewTransportChecksum(oldcsum, oldSrcIP.Addr(), oldSrcIP.Port(), newSrcIP.Addr(), newSrcIP.Port())
binary.BigEndian.PutUint16(data[offset:offset+2], checksum)
}
func recalcUDPv4Checksum(data []byte, oldSrcIP netip.AddrPort, newSrcIP netip.AddrPort) {
const offsetInsideHeader = 6
recalcV4TransportChecksum(offsetInsideHeader, data, oldSrcIP, newSrcIP)
}
func recalcTCPv4Checksum(data []byte, oldSrcIP netip.AddrPort, newSrcIP netip.AddrPort) {
const offsetInsideHeader = 16
recalcV4TransportChecksum(offsetInsideHeader, data, oldSrcIP, newSrcIP)
}
func calcNewICMPChecksum(oldChecksum uint16, oldCode uint16, newCode uint16, oldID uint16, newID uint16) uint16 {
// Start with inverted checksum
checksum := uint32(^oldChecksum)
// Subtract old stuff
checksum += uint32(^oldCode)
checksum += uint32(^oldID)
// Add new stuff
checksum += uint32(newCode)
checksum += uint32(newID)
// Fold carries
for checksum > 0xFFFF {
checksum = (checksum & 0xFFFF) + (checksum >> 16)
}
// Return ones' complement
return ^uint16(checksum)
}
func recalcICMPv4Checksum(data []byte, oldCode uint16, newCode uint16, oldID uint16, newID uint16) {
const offsetInsideHeader = 2
ipHeaderOffset := int(data[0]&0x0F) * 4
offset := ipHeaderOffset + offsetInsideHeader
oldChecksum := binary.BigEndian.Uint16(data[offset : offset+2])
checksum := calcNewICMPChecksum(oldChecksum, oldCode, newCode, oldID, newID)
binary.BigEndian.PutUint16(data[offset:offset+2], checksum)
}

1310
snat_test.go Normal file

File diff suppressed because it is too large Load diff

View file

@ -10,6 +10,18 @@ import (
type NoopTun struct{}
func (NoopTun) UnsafeNetworks() []netip.Prefix {
return nil
}
func (NoopTun) SNATAddress() netip.Prefix {
return netip.Prefix{}
}
func (NoopTun) UnsafeIPv4OriginAddress() netip.Prefix {
return netip.Prefix{}
}
func (NoopTun) RoutesFor(addr netip.Addr) routing.Gateways {
return routing.Gateways{}
}