mirror of
https://github.com/slackhq/nebula.git
synced 2026-03-13 02:02:56 -07:00
Merge 3e3bd9cead into 7760ccefba
This commit is contained in:
commit
143d230e2d
29 changed files with 2949 additions and 235 deletions
|
|
@ -439,11 +439,7 @@ func (c *certificateV2) validate() error {
|
|||
if !hasV6Networks {
|
||||
return NewErrInvalidCertificateProperties("IPv6 unsafe networks require an IPv6 address assignment: %s", network)
|
||||
}
|
||||
} else if network.Addr().Is4() {
|
||||
if !hasV4Networks {
|
||||
return NewErrInvalidCertificateProperties("IPv4 unsafe networks require an IPv4 address assignment: %s", network)
|
||||
}
|
||||
}
|
||||
} // as long as we have any IP address, IPv4 UnsafeNetworks are allowed
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
400
e2e/snat_test.go
Normal file
400
e2e/snat_test.go
Normal file
|
|
@ -0,0 +1,400 @@
|
|||
//go:build e2e_testing
|
||||
// +build e2e_testing
|
||||
|
||||
package e2e
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/cert_test"
|
||||
"github.com/slackhq/nebula/e2e/router"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// parseIPv4UDPPacket extracts source/dest IPs, ports, and payload from an IPv4 UDP packet.
|
||||
func parseIPv4UDPPacket(t testing.TB, pkt []byte) (srcIP, dstIP netip.Addr, srcPort, dstPort uint16, payload []byte) {
|
||||
t.Helper()
|
||||
require.True(t, len(pkt) >= 28, "packet too short for IPv4+UDP header")
|
||||
require.Equal(t, byte(0x45), pkt[0]&0xF0|pkt[0]&0x0F, "not a simple IPv4 packet (IHL!=5)")
|
||||
|
||||
srcIP, _ = netip.AddrFromSlice(pkt[12:16])
|
||||
dstIP, _ = netip.AddrFromSlice(pkt[16:20])
|
||||
|
||||
ihl := int(pkt[0]&0x0F) * 4
|
||||
require.True(t, len(pkt) >= ihl+8, "packet too short for UDP header")
|
||||
srcPort = binary.BigEndian.Uint16(pkt[ihl : ihl+2])
|
||||
dstPort = binary.BigEndian.Uint16(pkt[ihl+2 : ihl+4])
|
||||
udpLen := binary.BigEndian.Uint16(pkt[ihl+4 : ihl+6])
|
||||
payload = pkt[ihl+8 : ihl+int(udpLen)]
|
||||
return
|
||||
}
|
||||
|
||||
func TestSNAT_IPv6OnlyPeer_IPv4UnsafeTraffic(t *testing.T) {
|
||||
// Scenario: Two IPv6-only VPN nodes. The "router" node has unsafe networks
|
||||
// (192.168.0.0/16) in its cert and a configured SNAT address. The "sender"
|
||||
// node has an unsafe route for 192.168.0.0/16 via the router.
|
||||
//
|
||||
// When sender injects an IPv4 packet destined for the unsafe network, it
|
||||
// gets tunneled to the router. The router's firewall detects this is IPv4
|
||||
// from an IPv6-only peer and applies SNAT, rewriting the source IP to the
|
||||
// SNAT address before delivering it to TUN.
|
||||
//
|
||||
// When a reply comes back from TUN addressed to the SNAT address, the
|
||||
// router un-SNATs it (restoring the original destination) and tunnels it
|
||||
// back to the sender.
|
||||
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
|
||||
unsafePrefix := "192.168.0.0/16"
|
||||
snatAddr := netip.MustParseAddr("169.254.42.42")
|
||||
|
||||
// Router: IPv6-only with unsafe networks and a manual SNAT address.
|
||||
// Override inbound firewall with local_cidr: "any" so both IPv4 (unsafe)
|
||||
// and IPv6 (VPN) traffic is accepted.
|
||||
routerControl, routerVpnIpNet, routerUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks(
|
||||
cert.Version2, ca, caKey, "router", "ff::1/64",
|
||||
netip.MustParseAddrPort("[beef::1]:4242"),
|
||||
unsafePrefix,
|
||||
m{
|
||||
"firewall": m{
|
||||
"inbound": []m{{
|
||||
"proto": "any",
|
||||
"port": "any",
|
||||
"host": "any",
|
||||
"local_cidr": "any",
|
||||
}},
|
||||
},
|
||||
"tun": m{
|
||||
"snat_address_for_4over6": snatAddr.String(),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
// Sender: IPv6-only with an unsafe route via the router
|
||||
senderControl, _, _, _ := newSimpleServerWithUdp(
|
||||
cert.Version2, ca, caKey, "sender", "ff::2/64",
|
||||
netip.MustParseAddrPort("[beef::2]:4242"),
|
||||
m{
|
||||
"tun": m{
|
||||
"unsafe_routes": []m{
|
||||
{"route": unsafePrefix, "via": routerVpnIpNet[0].Addr().String()},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
// Tell sender where the router lives
|
||||
senderControl.InjectLightHouseAddr(routerVpnIpNet[0].Addr(), routerUdpAddr)
|
||||
|
||||
// Build the router and start both nodes
|
||||
r := router.NewR(t, routerControl, senderControl)
|
||||
defer r.RenderFlow()
|
||||
|
||||
routerControl.Start()
|
||||
senderControl.Start()
|
||||
|
||||
// --- Outbound: sender -> IPv4 unsafe dest (via router with SNAT) ---
|
||||
|
||||
origSrcIP := netip.MustParseAddr("10.0.0.1")
|
||||
unsafeDest := netip.MustParseAddr("192.168.1.1")
|
||||
var origSrcPort uint16 = 12345
|
||||
var dstPort uint16 = 80
|
||||
|
||||
t.Log("Sender injects an IPv4 packet to the unsafe network")
|
||||
senderControl.InjectTunUDPPacket(unsafeDest, dstPort, origSrcIP, origSrcPort, []byte("snat me"))
|
||||
|
||||
t.Log("Route packets (handshake + data) until the router gets the packet on TUN")
|
||||
snatPkt := r.RouteForAllUntilTxTun(routerControl)
|
||||
|
||||
t.Log("Verify the packet was SNATted")
|
||||
gotSrcIP, gotDstIP, gotSrcPort, gotDstPort, gotPayload := parseIPv4UDPPacket(t, snatPkt)
|
||||
assert.Equal(t, snatAddr, gotSrcIP, "source IP should be rewritten to the SNAT address")
|
||||
assert.Equal(t, unsafeDest, gotDstIP, "destination IP should be unchanged")
|
||||
assert.Equal(t, dstPort, gotDstPort, "destination port should be unchanged")
|
||||
assert.Equal(t, []byte("snat me"), gotPayload, "payload should be unchanged")
|
||||
|
||||
// Capture the SNAT port (may differ from original if port was remapped)
|
||||
snatPort := gotSrcPort
|
||||
t.Logf("SNAT port: %d (original: %d)", snatPort, origSrcPort)
|
||||
|
||||
// --- Return: reply from unsafe dest -> un-SNATted back to sender ---
|
||||
|
||||
t.Log("Router injects a reply packet from the unsafe dest to the SNAT address")
|
||||
routerControl.InjectTunUDPPacket(snatAddr, snatPort, unsafeDest, dstPort, []byte("reply from unsafe"))
|
||||
|
||||
t.Log("Route until sender gets the reply on TUN")
|
||||
replyPkt := r.RouteForAllUntilTxTun(senderControl)
|
||||
|
||||
t.Log("Verify the reply was un-SNATted")
|
||||
replySrcIP, replyDstIP, replySrcPort, replyDstPort, replyPayload := parseIPv4UDPPacket(t, replyPkt)
|
||||
assert.Equal(t, unsafeDest, replySrcIP, "reply source should be the unsafe dest")
|
||||
assert.Equal(t, origSrcIP, replyDstIP, "reply dest should be the original source IP (un-SNATted)")
|
||||
assert.Equal(t, dstPort, replySrcPort, "reply source port should be the unsafe dest port")
|
||||
assert.Equal(t, origSrcPort, replyDstPort, "reply dest port should be the original source port (un-SNATted)")
|
||||
assert.Equal(t, []byte("reply from unsafe"), replyPayload, "payload should be unchanged")
|
||||
|
||||
r.RenderHostmaps("Final hostmaps", routerControl, senderControl)
|
||||
|
||||
// Also verify normal IPv6 VPN traffic still works between the nodes
|
||||
t.Log("Verify normal IPv6 VPN tunnel still works")
|
||||
assertTunnel(t, routerVpnIpNet[0].Addr(), senderControl.GetVpnAddrs()[0], routerControl, senderControl, r)
|
||||
|
||||
routerControl.Stop()
|
||||
senderControl.Stop()
|
||||
}
|
||||
|
||||
func TestSNAT_MultipleFlows(t *testing.T) {
|
||||
// Test that multiple distinct IPv4 flows from the same IPv6-only peer
|
||||
// are tracked independently through SNAT.
|
||||
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
|
||||
unsafePrefix := "192.168.0.0/16"
|
||||
snatAddr := netip.MustParseAddr("169.254.42.42")
|
||||
|
||||
routerControl, routerVpnIpNet, routerUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks(
|
||||
cert.Version2, ca, caKey, "router", "ff::1/64",
|
||||
netip.MustParseAddrPort("[beef::1]:4242"),
|
||||
unsafePrefix,
|
||||
m{
|
||||
"firewall": m{
|
||||
"inbound": []m{{
|
||||
"proto": "any",
|
||||
"port": "any",
|
||||
"host": "any",
|
||||
"local_cidr": "any",
|
||||
}},
|
||||
},
|
||||
"tun": m{
|
||||
"snat_address_for_4over6": snatAddr.String(),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
senderControl, _, _, _ := newSimpleServerWithUdp(
|
||||
cert.Version2, ca, caKey, "sender", "ff::2/64",
|
||||
netip.MustParseAddrPort("[beef::2]:4242"),
|
||||
m{
|
||||
"tun": m{
|
||||
"unsafe_routes": []m{
|
||||
{"route": unsafePrefix, "via": routerVpnIpNet[0].Addr().String()},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
senderControl.InjectLightHouseAddr(routerVpnIpNet[0].Addr(), routerUdpAddr)
|
||||
|
||||
r := router.NewR(t, routerControl, senderControl)
|
||||
defer r.RenderFlow()
|
||||
r.CancelFlowLogs()
|
||||
|
||||
routerControl.Start()
|
||||
senderControl.Start()
|
||||
|
||||
unsafeDest := netip.MustParseAddr("192.168.1.1")
|
||||
|
||||
// Send first flow
|
||||
senderControl.InjectTunUDPPacket(unsafeDest, 80, netip.MustParseAddr("10.0.0.1"), 1111, []byte("flow1"))
|
||||
pkt1 := r.RouteForAllUntilTxTun(routerControl)
|
||||
srcIP1, _, srcPort1, _, payload1 := parseIPv4UDPPacket(t, pkt1)
|
||||
assert.Equal(t, snatAddr, srcIP1)
|
||||
assert.Equal(t, []byte("flow1"), payload1)
|
||||
|
||||
// Send second flow (different source port)
|
||||
senderControl.InjectTunUDPPacket(unsafeDest, 80, netip.MustParseAddr("10.0.0.1"), 2222, []byte("flow2"))
|
||||
pkt2 := r.RouteForAllUntilTxTun(routerControl)
|
||||
srcIP2, _, srcPort2, _, payload2 := parseIPv4UDPPacket(t, pkt2)
|
||||
assert.Equal(t, snatAddr, srcIP2)
|
||||
assert.Equal(t, []byte("flow2"), payload2)
|
||||
|
||||
// The two flows should have different SNAT ports (since they're different conntracks)
|
||||
t.Logf("Flow 1 SNAT port: %d, Flow 2 SNAT port: %d", srcPort1, srcPort2)
|
||||
|
||||
// Reply to flow 2 first (out of order)
|
||||
routerControl.InjectTunUDPPacket(snatAddr, srcPort2, unsafeDest, 80, []byte("reply2"))
|
||||
reply2 := r.RouteForAllUntilTxTun(senderControl)
|
||||
_, replyDst2, _, replyDstPort2, replyPayload2 := parseIPv4UDPPacket(t, reply2)
|
||||
assert.Equal(t, netip.MustParseAddr("10.0.0.1"), replyDst2)
|
||||
assert.Equal(t, uint16(2222), replyDstPort2, "reply to flow 2 should restore original port 2222")
|
||||
assert.Equal(t, []byte("reply2"), replyPayload2)
|
||||
|
||||
// Reply to flow 1
|
||||
routerControl.InjectTunUDPPacket(snatAddr, srcPort1, unsafeDest, 80, []byte("reply1"))
|
||||
reply1 := r.RouteForAllUntilTxTun(senderControl)
|
||||
_, replyDst1, _, replyDstPort1, replyPayload1 := parseIPv4UDPPacket(t, reply1)
|
||||
assert.Equal(t, netip.MustParseAddr("10.0.0.1"), replyDst1)
|
||||
assert.Equal(t, uint16(1111), replyDstPort1, "reply to flow 1 should restore original port 1111")
|
||||
assert.Equal(t, []byte("reply1"), replyPayload1)
|
||||
|
||||
routerControl.Stop()
|
||||
senderControl.Stop()
|
||||
}
|
||||
|
||||
// --- Adversarial SNAT E2E Tests ---
|
||||
|
||||
func TestSNAT_UnsolicitedReplyDropped(t *testing.T) {
|
||||
// Without any outbound SNAT traffic, inject a packet from the router's TUN
|
||||
// addressed to the SNAT address. The sender must never receive it because
|
||||
// there's no conntrack entry to un-SNAT through.
|
||||
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
|
||||
unsafePrefix := "192.168.0.0/16"
|
||||
snatAddr := netip.MustParseAddr("169.254.42.42")
|
||||
|
||||
routerControl, routerVpnIpNet, routerUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks(
|
||||
cert.Version2, ca, caKey, "router", "ff::1/64",
|
||||
netip.MustParseAddrPort("[beef::1]:4242"),
|
||||
unsafePrefix,
|
||||
m{
|
||||
"firewall": m{
|
||||
"inbound": []m{{
|
||||
"proto": "any",
|
||||
"port": "any",
|
||||
"host": "any",
|
||||
"local_cidr": "any",
|
||||
}},
|
||||
},
|
||||
"tun": m{
|
||||
"snat_address_for_4over6": snatAddr.String(),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
senderControl, _, _, _ := newSimpleServerWithUdp(
|
||||
cert.Version2, ca, caKey, "sender", "ff::2/64",
|
||||
netip.MustParseAddrPort("[beef::2]:4242"),
|
||||
m{
|
||||
"tun": m{
|
||||
"unsafe_routes": []m{
|
||||
{"route": unsafePrefix, "via": routerVpnIpNet[0].Addr().String()},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
senderControl.InjectLightHouseAddr(routerVpnIpNet[0].Addr(), routerUdpAddr)
|
||||
|
||||
r := router.NewR(t, routerControl, senderControl)
|
||||
defer r.RenderFlow()
|
||||
r.CancelFlowLogs()
|
||||
|
||||
routerControl.Start()
|
||||
senderControl.Start()
|
||||
|
||||
// First establish the tunnel with normal IPv6 traffic so handshake completes
|
||||
assertTunnel(t, routerVpnIpNet[0].Addr(), senderControl.GetVpnAddrs()[0], routerControl, senderControl, r)
|
||||
|
||||
// Inject the unsolicited reply from router's TUN to the SNAT address.
|
||||
// There is NO prior outbound SNAT flow, so no conntrack entry exists.
|
||||
// The router should silently drop this because unSnat finds no matching conntrack.
|
||||
routerControl.InjectTunUDPPacket(snatAddr, 55555, netip.MustParseAddr("192.168.1.1"), 80, []byte("unsolicited"))
|
||||
|
||||
// Send a canary IPv6 VPN packet after the bad one. Since the router processes
|
||||
// TUN packets sequentially, the canary arriving proves the bad packet was processed first.
|
||||
senderVpnAddr := senderControl.GetVpnAddrs()[0]
|
||||
routerControl.InjectTunUDPPacket(senderVpnAddr, 90, routerVpnIpNet[0].Addr(), 80, []byte("canary"))
|
||||
canaryPkt := r.RouteForAllUntilTxTun(senderControl)
|
||||
assertUdpPacket(t, []byte("canary"), canaryPkt, routerVpnIpNet[0].Addr(), senderVpnAddr, 80, 90)
|
||||
|
||||
// The unsolicited packet should have been dropped — nothing else on sender's TUN
|
||||
got := senderControl.GetFromTun(false)
|
||||
assert.Nil(t, got, "sender should not receive unsolicited packet to SNAT address with no conntrack entry")
|
||||
|
||||
routerControl.Stop()
|
||||
senderControl.Stop()
|
||||
}
|
||||
|
||||
func TestSNAT_NonUnsafeDestDropped(t *testing.T) {
|
||||
// An IPv6-only sender sends IPv4 traffic to a destination outside the router's
|
||||
// unsafe networks (172.16.0.1 when unsafe is 192.168.0.0/16). The router should
|
||||
// reject this because the local address is not routable. This verifies that
|
||||
// willingToHandleLocalAddr enforces boundaries on what SNAT traffic can reach.
|
||||
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
|
||||
unsafePrefix := "192.168.0.0/16"
|
||||
snatAddr := netip.MustParseAddr("169.254.42.42")
|
||||
|
||||
routerControl, routerVpnIpNet, routerUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks(
|
||||
cert.Version2, ca, caKey, "router", "ff::1/64",
|
||||
netip.MustParseAddrPort("[beef::1]:4242"),
|
||||
unsafePrefix,
|
||||
m{
|
||||
"firewall": m{
|
||||
"inbound": []m{{
|
||||
"proto": "any",
|
||||
"port": "any",
|
||||
"host": "any",
|
||||
"local_cidr": "any",
|
||||
}},
|
||||
},
|
||||
"tun": m{
|
||||
"snat_address_for_4over6": snatAddr.String(),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
// Sender has unsafe routes for BOTH 192.168.0.0/16 AND 172.16.0.0/12 via router.
|
||||
// This means the sender will route 172.16.0.1 through the tunnel to the router.
|
||||
// But the router should reject it because 172.16.0.0/12 is NOT in its unsafe networks.
|
||||
senderControl, _, _, _ := newSimpleServerWithUdp(
|
||||
cert.Version2, ca, caKey, "sender", "ff::2/64",
|
||||
netip.MustParseAddrPort("[beef::2]:4242"),
|
||||
m{
|
||||
"tun": m{
|
||||
"unsafe_routes": []m{
|
||||
{"route": unsafePrefix, "via": routerVpnIpNet[0].Addr().String()},
|
||||
{"route": "172.16.0.0/12", "via": routerVpnIpNet[0].Addr().String()},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
senderControl.InjectLightHouseAddr(routerVpnIpNet[0].Addr(), routerUdpAddr)
|
||||
|
||||
r := router.NewR(t, routerControl, senderControl)
|
||||
defer r.RenderFlow()
|
||||
r.CancelFlowLogs()
|
||||
|
||||
routerControl.Start()
|
||||
senderControl.Start()
|
||||
|
||||
// Establish the tunnel first
|
||||
assertTunnel(t, routerVpnIpNet[0].Addr(), senderControl.GetVpnAddrs()[0], routerControl, senderControl, r)
|
||||
|
||||
// Send to 172.16.0.1 (NOT in router's unsafe networks 192.168.0.0/16).
|
||||
// The router should reject this at willingToHandleLocalAddr.
|
||||
senderControl.InjectTunUDPPacket(
|
||||
netip.MustParseAddr("172.16.0.1"), 80,
|
||||
netip.MustParseAddr("10.0.0.1"), 12345,
|
||||
[]byte("wrong dest"),
|
||||
)
|
||||
|
||||
// Send a canary to a valid unsafe destination to prove the bad packet was processed
|
||||
senderControl.InjectTunUDPPacket(
|
||||
netip.MustParseAddr("192.168.1.1"), 80,
|
||||
netip.MustParseAddr("10.0.0.1"), 33333,
|
||||
[]byte("canary"),
|
||||
)
|
||||
|
||||
// Route until the canary arrives — the 172.16.0.1 packet should have been
|
||||
// processed and dropped before the canary gets through
|
||||
canaryPkt := r.RouteForAllUntilTxTun(routerControl)
|
||||
_, canaryDst, _, _, canaryPayload := parseIPv4UDPPacket(t, canaryPkt)
|
||||
assert.Equal(t, netip.MustParseAddr("192.168.1.1"), canaryDst, "canary should arrive at the valid unsafe dest")
|
||||
assert.Equal(t, []byte("canary"), canaryPayload)
|
||||
|
||||
// No more packets — the 172.16.0.1 packet was dropped
|
||||
got := routerControl.GetFromTun(false)
|
||||
assert.Nil(t, got, "packet to non-unsafe destination 172.16.0.1 should be dropped by the router")
|
||||
|
||||
routerControl.Stop()
|
||||
senderControl.Stop()
|
||||
}
|
||||
311
firewall.go
311
firewall.go
|
|
@ -2,6 +2,7 @@ package nebula
|
|||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
|
@ -22,10 +23,35 @@ import (
|
|||
"github.com/slackhq/nebula/firewall"
|
||||
)
|
||||
|
||||
var ErrCannotSNAT = errors.New("cannot SNAT this packet")
|
||||
var ErrSNATIdentityMismatch = errors.New("refusing to SNAT for mismatched host")
|
||||
var ErrSNATAddressCollision = errors.New("refusing to accept an incoming packet with my SNAT address")
|
||||
|
||||
const ipv4SourcePosition = 12
|
||||
const ipv4DestinationPosition = 16
|
||||
const sourcePortOffset = 0
|
||||
const destinationPortOffset = 2
|
||||
|
||||
type FirewallInterface interface {
|
||||
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr string, caName string, caSha string) error
|
||||
}
|
||||
|
||||
type snatInfo struct {
|
||||
//Src is the source IP+port to write into unsafe-route-bound packet
|
||||
Src netip.AddrPort
|
||||
//SrcVpnIp is the overlay IP associated with this flow. It's needed to associate reply traffic so we can get it back to the right host.
|
||||
SrcVpnIp netip.Addr
|
||||
//SnatPort is the port to rewrite into an overlay-bound packet
|
||||
SnatPort uint16
|
||||
}
|
||||
|
||||
func (s *snatInfo) Valid() bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
return s.Src.IsValid()
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
Expires time.Time // Time when this conntrack entry will expire
|
||||
|
||||
|
|
@ -34,6 +60,9 @@ type conn struct {
|
|||
// fields pack for free after the uint32 above
|
||||
incoming bool
|
||||
rulesVersion uint16
|
||||
|
||||
//for SNAT support
|
||||
snat *snatInfo
|
||||
}
|
||||
|
||||
// TODO: need conntrack max tracked connections handling
|
||||
|
|
@ -66,6 +95,8 @@ type Firewall struct {
|
|||
defaultLocalCIDRAny bool
|
||||
incomingMetrics firewallMetrics
|
||||
outgoingMetrics firewallMetrics
|
||||
unsafeIPv4Origin netip.Addr
|
||||
snatAddr netip.Addr
|
||||
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
|
@ -193,12 +224,12 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
|
|||
|
||||
func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) {
|
||||
certificate := cs.getCertificate(cert.Version2)
|
||||
if certificate == nil {
|
||||
if certificate == nil { //todo if config.initiating_version is set to 1, and unsafe_networks differ, things will suck
|
||||
certificate = cs.getCertificate(cert.Version1)
|
||||
}
|
||||
|
||||
if certificate == nil {
|
||||
panic("No certificate available to reconfigure the firewall")
|
||||
return nil, errors.New("no certificate available to reconfigure the firewall")
|
||||
}
|
||||
|
||||
fw := NewFirewall(
|
||||
|
|
@ -207,7 +238,6 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
|
|||
c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
|
||||
c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
|
||||
certificate,
|
||||
//TODO: max_connections
|
||||
)
|
||||
|
||||
fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false)
|
||||
|
|
@ -314,6 +344,18 @@ func (f *Firewall) GetRuleHashes() string {
|
|||
return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10)
|
||||
}
|
||||
|
||||
func (f *Firewall) SetSNATAddressFromInterface(i *Interface) {
|
||||
//address-mutation-avoidance is done inside Interface, the firewall doesn't need to care
|
||||
//todo should snatted conntracks get expired out? Probably not needed until if/when we allow reload
|
||||
f.snatAddr = i.inside.SNATAddress().Addr()
|
||||
f.unsafeIPv4Origin = i.inside.UnsafeIPv4OriginAddress().Addr()
|
||||
}
|
||||
|
||||
func (f *Firewall) ShouldUnSNAT(fp *firewall.Packet) bool {
|
||||
// f.snatAddr is only valid if we're a snat-capable router
|
||||
return f.snatAddr.IsValid() && fp.RemoteAddr == f.snatAddr
|
||||
}
|
||||
|
||||
func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error {
|
||||
var table string
|
||||
if inbound {
|
||||
|
|
@ -414,50 +456,207 @@ var ErrInvalidRemoteIP = errors.New("remote address is not in remote certificate
|
|||
var ErrInvalidLocalIP = errors.New("local address is not in list of handled local addresses")
|
||||
var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
|
||||
|
||||
func (f *Firewall) unSnat(data []byte, fp *firewall.Packet) netip.Addr {
|
||||
c := f.peek(*fp) //unfortunately this needs to lock. Surely there's a better way.
|
||||
if c == nil {
|
||||
return netip.Addr{}
|
||||
}
|
||||
if !c.snat.Valid() {
|
||||
return netip.Addr{}
|
||||
}
|
||||
oldIP := netip.AddrPortFrom(f.snatAddr, fp.RemotePort)
|
||||
rewritePacket(data, fp, oldIP, c.snat.Src, ipv4DestinationPosition, destinationPortOffset)
|
||||
return c.snat.SrcVpnIp
|
||||
}
|
||||
|
||||
func rewritePacket(data []byte, fp *firewall.Packet, oldIP netip.AddrPort, newIP netip.AddrPort, ipOffset int, portOffset int) {
|
||||
//change address
|
||||
copy(data[ipOffset:], newIP.Addr().AsSlice())
|
||||
recalcIPv4Checksum(data, oldIP.Addr(), newIP.Addr())
|
||||
ipHeaderLen := int(data[0]&0x0F) * 4
|
||||
|
||||
switch fp.Protocol {
|
||||
case firewall.ProtoICMP:
|
||||
binary.BigEndian.PutUint16(data[ipHeaderLen+4:ipHeaderLen+6], newIP.Port()) //we use the ID field as a "port" for ICMP
|
||||
icmpCode := uint16(data[ipHeaderLen+1]) //todo not snatting on code yet (but Linux would)
|
||||
recalcICMPv4Checksum(data, icmpCode, icmpCode, oldIP.Port(), newIP.Port())
|
||||
case firewall.ProtoUDP:
|
||||
dstport := ipHeaderLen + portOffset
|
||||
binary.BigEndian.PutUint16(data[dstport:dstport+2], newIP.Port())
|
||||
recalcUDPv4Checksum(data, oldIP, newIP)
|
||||
case firewall.ProtoTCP:
|
||||
dstport := ipHeaderLen + portOffset
|
||||
binary.BigEndian.PutUint16(data[dstport:dstport+2], newIP.Port())
|
||||
recalcTCPv4Checksum(data, oldIP, newIP)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Firewall) findUsableSNATPort(fp *firewall.Packet, c *conn) error {
|
||||
const halfThePorts = 0x7fff
|
||||
oldPort := fp.RemotePort
|
||||
conntrack := f.Conntrack
|
||||
conntrack.Lock()
|
||||
defer conntrack.Unlock()
|
||||
for numPortsChecked := 0; numPortsChecked < halfThePorts; numPortsChecked++ {
|
||||
_, ok := conntrack.Conns[*fp]
|
||||
if !ok {
|
||||
//yay, we can use this port
|
||||
//track the snatted flow with the same expiration as the unsnatted version
|
||||
c.snat.SnatPort = fp.RemotePort
|
||||
conntrack.Conns[*fp] = c
|
||||
return nil
|
||||
}
|
||||
//increment and retry. There's probably better strategies out there
|
||||
fp.RemotePort++
|
||||
if fp.RemotePort < halfThePorts {
|
||||
fp.RemotePort += halfThePorts // keep it ephemeral for now
|
||||
}
|
||||
}
|
||||
|
||||
//if we made it here, we failed
|
||||
fp.RemotePort = oldPort
|
||||
return ErrCannotSNAT
|
||||
}
|
||||
|
||||
func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo *HostInfo) error {
|
||||
if !f.snatAddr.IsValid() {
|
||||
return ErrCannotSNAT
|
||||
}
|
||||
if f.snatAddr == fp.LocalAddr { //a packet that came from UDP (incoming) should never ever have our snat address on it
|
||||
return ErrSNATAddressCollision
|
||||
}
|
||||
if c.snat.Valid() {
|
||||
//old flow: make sure it came from the right place
|
||||
if !slices.Contains(hostinfo.vpnAddrs, c.snat.SrcVpnIp) {
|
||||
return ErrSNATIdentityMismatch
|
||||
}
|
||||
fp.RemoteAddr = f.snatAddr
|
||||
fp.RemotePort = c.snat.SnatPort
|
||||
} else if hostinfo.vpnAddrs[0].Is6() {
|
||||
//we got a new flow
|
||||
c.snat = &snatInfo{
|
||||
Src: netip.AddrPortFrom(fp.RemoteAddr, fp.RemotePort),
|
||||
SrcVpnIp: hostinfo.vpnAddrs[0],
|
||||
}
|
||||
fp.RemoteAddr = f.snatAddr
|
||||
//find a new port to use, if needed
|
||||
err := f.findUsableSNATPort(fp, c)
|
||||
if err != nil {
|
||||
c.snat = nil
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return ErrCannotSNAT
|
||||
}
|
||||
|
||||
newIP := netip.AddrPortFrom(f.snatAddr, c.snat.SnatPort)
|
||||
rewritePacket(data, fp, c.snat.Src, newIP, ipv4SourcePosition, sourcePortOffset)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *Firewall) identifyRemoteNetworkType(h *HostInfo, fp firewall.Packet) NetworkType {
|
||||
if h.networks == nil {
|
||||
// Simple case: Certificate has one address and no unsafe networks
|
||||
if h.vpnAddrs[0] == fp.RemoteAddr {
|
||||
return NetworkTypeVPN
|
||||
} //else, fallthrough
|
||||
} else if nwType, ok := h.networks.Lookup(fp.RemoteAddr); ok {
|
||||
return nwType //will return NetworkTypeVPN or NetworkTypeUnsafe
|
||||
}
|
||||
|
||||
//RemoteAddr not in our networks table
|
||||
if f.snatAddr.IsValid() && fp.IsIPv4() && h.HasOnlyV6Addresses() {
|
||||
return NetworkTypeUnverifiedSNATPeer
|
||||
} else {
|
||||
return NetworkTypeInvalidPeer
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Firewall) allowRemoteNetworkType(nwType NetworkType, fp firewall.Packet) error {
|
||||
switch nwType {
|
||||
case NetworkTypeVPN:
|
||||
return nil
|
||||
case NetworkTypeInvalidPeer:
|
||||
return ErrInvalidRemoteIP
|
||||
case NetworkTypeVPNPeer:
|
||||
//one day we might need a specialSnatMode case in here to handle routers with v4 addresses when we don't also have a v4 address?
|
||||
return ErrPeerRejected // reject for now, one day this may have different FW rules
|
||||
case NetworkTypeUnsafe:
|
||||
return nil // nothing special, one day this may have different FW rules
|
||||
case NetworkTypeUnverifiedSNATPeer:
|
||||
if f.unsafeIPv4Origin.IsValid() && fp.LocalAddr == f.unsafeIPv4Origin {
|
||||
return nil //the client case
|
||||
}
|
||||
if f.snatAddr.IsValid() {
|
||||
if fp.RemoteAddr == f.snatAddr {
|
||||
return ErrInvalidRemoteIP //we should never get a packet with our SNAT addr as the destination, or "from" our SNAT addr
|
||||
}
|
||||
return nil
|
||||
} else {
|
||||
return ErrInvalidRemoteIP
|
||||
}
|
||||
default:
|
||||
return ErrUnknownNetworkType //should never happen
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Firewall) willingToHandleLocalAddr(incoming bool, fp firewall.Packet, remoteNwType NetworkType) error {
|
||||
if f.routableNetworks.Contains(fp.LocalAddr) {
|
||||
return nil //easy, this should handle NetworkTypeVPN in all cases, and NetworkTypeUnsafe on the router side
|
||||
}
|
||||
if incoming { //at least for now, reject all traffic other than what we've already decided is locally routable
|
||||
return ErrInvalidLocalIP
|
||||
}
|
||||
|
||||
//below this line, all traffic is outgoing. Outgoing traffic to NetworkTypeUnsafe is not required to be considered inbound-routable
|
||||
if remoteNwType == NetworkTypeUnsafe {
|
||||
return nil
|
||||
}
|
||||
|
||||
return ErrInvalidLocalIP
|
||||
}
|
||||
|
||||
// Drop returns an error if the packet should be dropped, explaining why. It
|
||||
// returns nil if the packet should not be dropped.
|
||||
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error {
|
||||
func (f *Firewall) Drop(fp firewall.Packet, pkt []byte, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error {
|
||||
table := f.OutRules
|
||||
if incoming {
|
||||
table = f.InRules
|
||||
}
|
||||
|
||||
snatmode := fp.IsIPv4() && h.HasOnlyV6Addresses() && f.snatAddr.IsValid()
|
||||
if snatmode {
|
||||
//if this is an IPv4 packet from a V6 only host, and we're configured to snat that kind of traffic, it must be snatted,
|
||||
//so it can never be in the localcache, which lacks SNAT data
|
||||
//nil out the pointer to avoid ever using it
|
||||
localCache = nil
|
||||
}
|
||||
|
||||
// Check if we spoke to this tuple, if we did then allow this packet
|
||||
if f.inConns(fp, h, caPool, localCache) {
|
||||
if localCache != nil {
|
||||
if _, ok := localCache[fp]; ok {
|
||||
return nil //packet matched the cache, we're not snatting, we can return early
|
||||
}
|
||||
}
|
||||
c := f.inConns(fp, h, caPool, localCache)
|
||||
if c != nil {
|
||||
if incoming && snatmode {
|
||||
return f.applySnat(pkt, &fp, c, h)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Make sure remote address matches nebula certificate, and determine how to treat it
|
||||
if h.networks == nil {
|
||||
// Simple case: Certificate has one address and no unsafe networks
|
||||
if h.vpnAddrs[0] != fp.RemoteAddr {
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
return ErrInvalidRemoteIP
|
||||
}
|
||||
} else {
|
||||
nwType, ok := h.networks.Lookup(fp.RemoteAddr)
|
||||
if !ok {
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
return ErrInvalidRemoteIP
|
||||
}
|
||||
switch nwType {
|
||||
case NetworkTypeVPN:
|
||||
break // nothing special
|
||||
case NetworkTypeVPNPeer:
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
return ErrPeerRejected // reject for now, one day this may have different FW rules
|
||||
case NetworkTypeUnsafe:
|
||||
break // nothing special, one day this may have different FW rules
|
||||
default:
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
return ErrUnknownNetworkType //should never happen
|
||||
}
|
||||
remoteNetworkType := f.identifyRemoteNetworkType(h, fp)
|
||||
if err := f.allowRemoteNetworkType(remoteNetworkType, fp); err != nil {
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
return err
|
||||
}
|
||||
|
||||
// Make sure we are supposed to be handling this local ip address
|
||||
if !f.routableNetworks.Contains(fp.LocalAddr) {
|
||||
if err := f.willingToHandleLocalAddr(incoming, fp, remoteNetworkType); err != nil {
|
||||
f.metrics(incoming).droppedLocalAddr.Inc(1)
|
||||
return ErrInvalidLocalIP
|
||||
}
|
||||
|
||||
table := f.OutRules
|
||||
if incoming {
|
||||
table = f.InRules
|
||||
return err
|
||||
}
|
||||
|
||||
// We now know which firewall table to check against
|
||||
|
|
@ -467,9 +666,14 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
|
|||
}
|
||||
|
||||
// We always want to conntrack since it is a faster operation
|
||||
f.addConn(fp, incoming)
|
||||
c = f.addConn(fp, incoming)
|
||||
|
||||
return nil
|
||||
if incoming && remoteNetworkType == NetworkTypeUnverifiedSNATPeer {
|
||||
return f.applySnat(pkt, &fp, c, h)
|
||||
} else {
|
||||
//outgoing snat is handled before this function is called
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Firewall) metrics(incoming bool) firewallMetrics {
|
||||
|
|
@ -496,12 +700,14 @@ func (f *Firewall) EmitStats() {
|
|||
metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
|
||||
}
|
||||
|
||||
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool {
|
||||
if localCache != nil {
|
||||
if _, ok := localCache[fp]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
func (f *Firewall) peek(fp firewall.Packet) *conn {
|
||||
f.Conntrack.Lock()
|
||||
c := f.Conntrack.Conns[fp]
|
||||
f.Conntrack.Unlock()
|
||||
return c
|
||||
}
|
||||
|
||||
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) *conn {
|
||||
conntrack := f.Conntrack
|
||||
conntrack.Lock()
|
||||
|
||||
|
|
@ -515,7 +721,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
|
|||
|
||||
if !ok {
|
||||
conntrack.Unlock()
|
||||
return false
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.rulesVersion != f.rulesVersion {
|
||||
|
|
@ -538,7 +744,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
|
|||
}
|
||||
delete(conntrack.Conns, fp)
|
||||
conntrack.Unlock()
|
||||
return false
|
||||
return nil
|
||||
}
|
||||
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
|
|
@ -568,12 +774,11 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
|
|||
localCache[fp] = struct{}{}
|
||||
}
|
||||
|
||||
return true
|
||||
return c
|
||||
}
|
||||
|
||||
func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
|
||||
func (f *Firewall) packetTimeout(fp firewall.Packet) time.Duration {
|
||||
var timeout time.Duration
|
||||
c := &conn{}
|
||||
|
||||
switch fp.Protocol {
|
||||
case firewall.ProtoTCP:
|
||||
|
|
@ -583,7 +788,13 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
|
|||
default:
|
||||
timeout = f.DefaultTimeout
|
||||
}
|
||||
return timeout
|
||||
}
|
||||
|
||||
func (f *Firewall) addConn(fp firewall.Packet, incoming bool) *conn {
|
||||
c := &conn{}
|
||||
|
||||
timeout := f.packetTimeout(fp)
|
||||
conntrack := f.Conntrack
|
||||
conntrack.Lock()
|
||||
if _, ok := conntrack.Conns[fp]; !ok {
|
||||
|
|
@ -597,7 +808,9 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
|
|||
c.rulesVersion = f.rulesVersion
|
||||
c.Expires = time.Now().Add(timeout)
|
||||
conntrack.Conns[fp] = c
|
||||
|
||||
conntrack.Unlock()
|
||||
return c
|
||||
}
|
||||
|
||||
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
|
||||
|
|
|
|||
|
|
@ -31,6 +31,10 @@ type Packet struct {
|
|||
Fragment bool
|
||||
}
|
||||
|
||||
func (fp *Packet) IsIPv4() bool {
|
||||
return fp.LocalAddr.Is4() && fp.RemoteAddr.Is4()
|
||||
}
|
||||
|
||||
func (fp *Packet) Copy() *Packet {
|
||||
return &Packet{
|
||||
LocalAddr: fp.LocalAddr,
|
||||
|
|
|
|||
186
firewall_test.go
186
firewall_test.go
|
|
@ -213,44 +213,44 @@ func TestFirewall_Drop(t *testing.T) {
|
|||
cp := cert.NewCAPool()
|
||||
|
||||
// Drop outbound
|
||||
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
|
||||
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, nil, false, &h, cp, nil))
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
|
||||
// Allow outbound because conntrack
|
||||
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil))
|
||||
|
||||
// test remote mismatch
|
||||
oldRemote := p.RemoteAddr
|
||||
p.RemoteAddr = netip.MustParseAddr("1.2.3.10")
|
||||
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
|
||||
assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrInvalidRemoteIP)
|
||||
p.RemoteAddr = oldRemote
|
||||
|
||||
// ensure signer doesn't get in the way of group checks
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
|
||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
|
||||
// test caSha doesn't drop on match
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
|
||||
|
||||
// ensure ca name doesn't get in the way of group checks
|
||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
|
||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
|
||||
// test caName doesn't drop on match
|
||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
|
||||
}
|
||||
|
||||
func TestFirewall_DropV6(t *testing.T) {
|
||||
|
|
@ -292,44 +292,44 @@ func TestFirewall_DropV6(t *testing.T) {
|
|||
cp := cert.NewCAPool()
|
||||
|
||||
// Drop outbound
|
||||
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
|
||||
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, nil, false, &h, cp, nil))
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
|
||||
// Allow outbound because conntrack
|
||||
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil))
|
||||
|
||||
// test remote mismatch
|
||||
oldRemote := p.RemoteAddr
|
||||
p.RemoteAddr = netip.MustParseAddr("fd12::56")
|
||||
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
|
||||
assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrInvalidRemoteIP)
|
||||
p.RemoteAddr = oldRemote
|
||||
|
||||
// ensure signer doesn't get in the way of group checks
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
|
||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
|
||||
// test caSha doesn't drop on match
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
|
||||
|
||||
// ensure ca name doesn't get in the way of group checks
|
||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
|
||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
|
||||
// test caName doesn't drop on match
|
||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
|
||||
}
|
||||
|
||||
func BenchmarkFirewallTable_match(b *testing.B) {
|
||||
|
|
@ -537,10 +537,10 @@ func TestFirewall_Drop2(t *testing.T) {
|
|||
cp := cert.NewCAPool()
|
||||
|
||||
// h1/c1 lacks the proper groups
|
||||
require.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule)
|
||||
require.ErrorIs(t, fw.Drop(p, nil, true, &h1, cp, nil), ErrNoMatchingRule)
|
||||
// c has the proper groups
|
||||
resetConntrack(fw)
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
|
||||
}
|
||||
|
||||
func TestFirewall_Drop3(t *testing.T) {
|
||||
|
|
@ -618,18 +618,18 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||
cp := cert.NewCAPool()
|
||||
|
||||
// c1 should pass because host match
|
||||
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h1, cp, nil))
|
||||
// c2 should pass because ca sha match
|
||||
resetConntrack(fw)
|
||||
require.NoError(t, fw.Drop(p, true, &h2, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h2, cp, nil))
|
||||
// c3 should fail because no match
|
||||
resetConntrack(fw)
|
||||
assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(p, nil, true, &h3, cp, nil), ErrNoMatchingRule)
|
||||
|
||||
// Test a remote address match
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", ""))
|
||||
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h1, cp, nil))
|
||||
}
|
||||
|
||||
func TestFirewall_Drop3V6(t *testing.T) {
|
||||
|
|
@ -667,7 +667,7 @@ func TestFirewall_Drop3V6(t *testing.T) {
|
|||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
cp := cert.NewCAPool()
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", ""))
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
|
||||
}
|
||||
|
||||
func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||
|
|
@ -709,12 +709,12 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||
cp := cert.NewCAPool()
|
||||
|
||||
// Drop outbound
|
||||
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
|
||||
// Allow outbound because conntrack
|
||||
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil))
|
||||
|
||||
oldFw := fw
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
|
|
@ -723,7 +723,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||
|
||||
// Allow outbound because conntrack and new rules allow port 10
|
||||
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil))
|
||||
|
||||
oldFw = fw
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
|
|
@ -732,7 +732,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||
|
||||
// Drop outbound because conntrack doesn't match new ruleset
|
||||
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
}
|
||||
|
||||
func TestFirewall_ICMPPortBehavior(t *testing.T) {
|
||||
|
|
@ -778,12 +778,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
|
|||
p.LocalPort = 0
|
||||
p.RemotePort = 0
|
||||
// Drop outbound
|
||||
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil))
|
||||
//now also allow outbound
|
||||
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil))
|
||||
})
|
||||
|
||||
t.Run("nonzero ports", func(t *testing.T) {
|
||||
|
|
@ -791,12 +791,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
|
|||
p.LocalPort = 0xabcd
|
||||
p.RemotePort = 0x1234
|
||||
// Drop outbound
|
||||
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil))
|
||||
//now also allow outbound
|
||||
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil))
|
||||
})
|
||||
})
|
||||
|
||||
|
|
@ -808,12 +808,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
|
|||
p.LocalPort = 0
|
||||
p.RemotePort = 0
|
||||
// Drop outbound
|
||||
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
//now also allow outbound
|
||||
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
})
|
||||
|
||||
t.Run("nonzero ports, still blocked", func(t *testing.T) {
|
||||
|
|
@ -821,12 +821,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
|
|||
p.LocalPort = 0xabcd
|
||||
p.RemotePort = 0x1234
|
||||
// Drop outbound
|
||||
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
//now also allow outbound
|
||||
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
})
|
||||
|
||||
t.Run("nonzero, matching ports, still blocked", func(t *testing.T) {
|
||||
|
|
@ -834,12 +834,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
|
|||
p.LocalPort = 80
|
||||
p.RemotePort = 80
|
||||
// Drop outbound
|
||||
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
//now also allow outbound
|
||||
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
})
|
||||
})
|
||||
t.Run("Any proto, any port", func(t *testing.T) {
|
||||
|
|
@ -851,12 +851,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
|
|||
p.LocalPort = 0
|
||||
p.RemotePort = 0
|
||||
// Drop outbound
|
||||
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil))
|
||||
//now also allow outbound
|
||||
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil))
|
||||
})
|
||||
|
||||
t.Run("nonzero ports, allowed", func(t *testing.T) {
|
||||
|
|
@ -865,15 +865,15 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
|
|||
p.LocalPort = 0xabcd
|
||||
p.RemotePort = 0x1234
|
||||
// Drop outbound
|
||||
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil))
|
||||
//now also allow outbound
|
||||
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil))
|
||||
//different ID is blocked
|
||||
p.RemotePort++
|
||||
require.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
require.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
})
|
||||
})
|
||||
|
||||
|
|
@ -922,7 +922,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
|
|||
Protocol: firewall.ProtoUDP,
|
||||
Fragment: false,
|
||||
}
|
||||
assert.Equal(t, fw.Drop(p, true, &h1, cp, nil), ErrInvalidRemoteIP)
|
||||
assert.Equal(t, fw.Drop(p, nil, true, &h1, cp, nil), ErrInvalidRemoteIP)
|
||||
}
|
||||
|
||||
func BenchmarkLookup(b *testing.B) {
|
||||
|
|
@ -1042,7 +1042,7 @@ func TestNewFirewallFromConfig(t *testing.T) {
|
|||
l := test.NewLogger()
|
||||
// Test a bad rule definition
|
||||
c := &dummyCert{}
|
||||
cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil)
|
||||
cs, err := newCertState(l, cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
conf := config.NewC(l)
|
||||
|
|
@ -1336,7 +1336,7 @@ func (c *testcase) Test(t *testing.T, fw *Firewall) {
|
|||
t.Helper()
|
||||
cp := cert.NewCAPool()
|
||||
resetConntrack(fw)
|
||||
err := fw.Drop(c.p, true, c.h, cp, nil)
|
||||
err := fw.Drop(c.p, nil, true, c.h, cp, nil)
|
||||
if c.err == nil {
|
||||
require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr)
|
||||
} else {
|
||||
|
|
@ -1344,7 +1344,7 @@ func (c *testcase) Test(t *testing.T, fw *Firewall) {
|
|||
}
|
||||
}
|
||||
|
||||
func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) testcase {
|
||||
func buildHostinfo(setup testsetup, theirPrefixes ...netip.Prefix) *HostInfo {
|
||||
c1 := dummyCert{
|
||||
name: "host1",
|
||||
networks: theirPrefixes,
|
||||
|
|
@ -1364,6 +1364,11 @@ func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) te
|
|||
h.vpnAddrs[i] = theirPrefixes[i].Addr()
|
||||
}
|
||||
h.buildNetworks(setup.myVpnNetworksTable, &c1)
|
||||
return &h
|
||||
}
|
||||
|
||||
func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) testcase {
|
||||
h := buildHostinfo(setup, theirPrefixes...)
|
||||
p := firewall.Packet{
|
||||
LocalAddr: setup.c.Networks()[0].Addr(), //todo?
|
||||
RemoteAddr: theirPrefixes[0].Addr(),
|
||||
|
|
@ -1373,9 +1378,9 @@ func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) te
|
|||
Fragment: false,
|
||||
}
|
||||
return testcase{
|
||||
h: &h,
|
||||
h: h,
|
||||
p: p,
|
||||
c: &c1,
|
||||
c: h.ConnectionState.peerCert.Certificate,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
|
@ -1397,6 +1402,19 @@ func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testse
|
|||
return newSetupFromCert(t, l, c)
|
||||
}
|
||||
|
||||
func newSnatSetup(t *testing.T, l *logrus.Logger, myPrefix netip.Prefix, snatAddr netip.Addr) testsetup {
|
||||
c := dummyCert{
|
||||
name: "me",
|
||||
networks: []netip.Prefix{myPrefix},
|
||||
groups: []string{"default-group"},
|
||||
issuer: "signer-shasum",
|
||||
}
|
||||
|
||||
out := newSetupFromCert(t, l, c)
|
||||
out.fw.snatAddr = snatAddr
|
||||
return out
|
||||
}
|
||||
|
||||
func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
|
||||
myVpnNetworksTable := new(bart.Lite)
|
||||
for _, prefix := range c.Networks() {
|
||||
|
|
@ -1532,3 +1550,59 @@ func resetConntrack(fw *Firewall) {
|
|||
fw.Conntrack.Conns = map[firewall.Packet]*conn{}
|
||||
fw.Conntrack.Unlock()
|
||||
}
|
||||
|
||||
func TestFirewall_SNAT(t *testing.T) {
|
||||
t.Parallel()
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
cp := cert.NewCAPool()
|
||||
myPrefix := netip.MustParsePrefix("1.1.1.1/8")
|
||||
|
||||
MyCert := dummyCert{
|
||||
name: "me",
|
||||
networks: []netip.Prefix{myPrefix},
|
||||
groups: []string{"default-group"},
|
||||
issuer: "signer-shasum",
|
||||
}
|
||||
|
||||
theirPrefix := netip.MustParsePrefix("1.2.2.2/8")
|
||||
snatAddr := netip.MustParseAddr("169.254.55.96")
|
||||
t.Run("allow inbound all matching", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
myCert := MyCert.Copy()
|
||||
setup := newSnatSetup(t, l, myPrefix, snatAddr)
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
|
||||
resetConntrack(setup.fw)
|
||||
h := buildHostinfo(setup, theirPrefix)
|
||||
p := firewall.Packet{
|
||||
LocalAddr: setup.c.Networks()[0].Addr(), //todo?
|
||||
RemoteAddr: h.vpnAddrs[0],
|
||||
LocalPort: 10,
|
||||
RemotePort: 90,
|
||||
Protocol: firewall.ProtoUDP,
|
||||
Fragment: false,
|
||||
}
|
||||
require.NoError(t, setup.fw.Drop(p, nil, true, h, cp, nil))
|
||||
})
|
||||
//t.Run("allow inbound unsafe route", func(t *testing.T) {
|
||||
// t.Parallel()
|
||||
// unsafePrefix := netip.MustParsePrefix("192.168.0.0/24")
|
||||
// c := dummyCert{
|
||||
// name: "me",
|
||||
// networks: []netip.Prefix{myPrefix},
|
||||
// unsafeNetworks: []netip.Prefix{unsafePrefix},
|
||||
// groups: []string{"default-group"},
|
||||
// issuer: "signer-shasum",
|
||||
// }
|
||||
// unsafeSetup := newSetupFromCert(t, l, c)
|
||||
// tc := buildTestCase(unsafeSetup, nil, twoPrefixes...)
|
||||
// tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3")
|
||||
// tc.err = ErrNoMatchingRule
|
||||
// tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off
|
||||
// require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", unsafePrefix.String(), "", ""))
|
||||
// tc.err = nil
|
||||
// tc.Test(t, unsafeSetup.fw) //should pass
|
||||
//})
|
||||
}
|
||||
|
|
|
|||
12
hostmap.go
12
hostmap.go
|
|
@ -224,6 +224,9 @@ const (
|
|||
NetworkTypeVPNPeer
|
||||
// NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks()
|
||||
NetworkTypeUnsafe
|
||||
// NetworkTypeUnverifiedSNATPeer is used to indicate traffic we're willing to route, but never deliver to a NetworkTypeVPN
|
||||
NetworkTypeUnverifiedSNATPeer
|
||||
NetworkTypeInvalidPeer
|
||||
)
|
||||
|
||||
type HostInfo struct {
|
||||
|
|
@ -277,6 +280,15 @@ type HostInfo struct {
|
|||
lastUsed time.Time
|
||||
}
|
||||
|
||||
func (i *HostInfo) HasOnlyV6Addresses() bool {
|
||||
for _, vpnIp := range i.vpnAddrs {
|
||||
if !vpnIp.Is6() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
type ViaSender struct {
|
||||
UdpAddr netip.AddrPort
|
||||
relayHI *HostInfo // relayHI is the host info object of the relay
|
||||
|
|
|
|||
29
inside.go
29
inside.go
|
|
@ -48,9 +48,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
|||
return
|
||||
}
|
||||
|
||||
hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
|
||||
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
|
||||
})
|
||||
hostinfo, ready := f.getHostinfo(packet, fwPacket)
|
||||
|
||||
if hostinfo == nil {
|
||||
f.rejectInside(packet, out, q)
|
||||
|
|
@ -66,10 +64,9 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
|||
return
|
||||
}
|
||||
|
||||
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
|
||||
dropReason := f.firewall.Drop(*fwPacket, packet, false, hostinfo, f.pki.GetCAPool(), localCache)
|
||||
if dropReason == nil {
|
||||
f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
|
||||
|
||||
} else {
|
||||
f.rejectInside(packet, out, q)
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
|
|
@ -81,6 +78,26 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
|||
}
|
||||
}
|
||||
|
||||
func (f *Interface) getHostinfo(packet []byte, fwPacket *firewall.Packet) (*HostInfo, bool) {
|
||||
if f.firewall.ShouldUnSNAT(fwPacket) {
|
||||
//unsnat packet re-writing also happens here, would be nice to not,
|
||||
//but we need to do the unsnat lookup to find the hostinfo so we can run the firewall checks
|
||||
destVpnAddr := f.firewall.unSnat(packet, fwPacket)
|
||||
if destVpnAddr.IsValid() {
|
||||
//because this was a snatted packet, we know it has an on-overlay destination, so no routing should be required.
|
||||
return f.getOrHandshakeNoRouting(destVpnAddr, func(hh *HandshakeHostInfo) {
|
||||
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
|
||||
})
|
||||
} else {
|
||||
return nil, false
|
||||
}
|
||||
} else { //if we didn't need to unsnat
|
||||
return f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
|
||||
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
|
||||
if !f.firewall.InSendReject {
|
||||
return
|
||||
|
|
@ -218,7 +235,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
|
|||
}
|
||||
|
||||
// check if packet is in outbound fw rules
|
||||
dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil)
|
||||
dropReason := f.firewall.Drop(*fp, p, false, hostinfo, f.pki.GetCAPool(), nil)
|
||||
if dropReason != nil {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("fwPacket", fp).
|
||||
|
|
|
|||
|
|
@ -248,6 +248,7 @@ func (f *Interface) activate() {
|
|||
f.inside.Close()
|
||||
f.l.Fatal(err)
|
||||
}
|
||||
f.firewall.SetSNATAddressFromInterface(f)
|
||||
}
|
||||
|
||||
func (f *Interface) run() {
|
||||
|
|
@ -344,6 +345,7 @@ func (f *Interface) reloadFirewall(c *config.C) {
|
|||
f.l.WithError(err).Error("Error while creating firewall during reload")
|
||||
return
|
||||
}
|
||||
fw.SetSNATAddressFromInterface(f)
|
||||
|
||||
oldFw := f.firewall
|
||||
conntrack := oldFw.Conntrack
|
||||
|
|
|
|||
3
main.go
3
main.go
|
|
@ -131,7 +131,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||
deviceFactory = overlay.NewDeviceFromConfig
|
||||
}
|
||||
|
||||
tun, err = deviceFactory(c, l, pki.getCertState().myVpnNetworks, routines)
|
||||
cs := pki.getCertState()
|
||||
tun, err = deviceFactory(c, l, cs.myVpnNetworks, cs.GetDefaultCertificate().UnsafeNetworks(), routines)
|
||||
if err != nil {
|
||||
return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -514,7 +514,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
|||
return false
|
||||
}
|
||||
|
||||
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
|
||||
dropReason := f.firewall.Drop(*fwPacket, out, true, hostinfo, f.pki.GetCAPool(), localCache)
|
||||
if dropReason != nil {
|
||||
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
|
||||
// This gives us a buffer to build the reject packet in
|
||||
|
|
|
|||
|
|
@ -11,6 +11,9 @@ type Device interface {
|
|||
io.ReadWriteCloser
|
||||
Activate() error
|
||||
Networks() []netip.Prefix
|
||||
UnsafeNetworks() []netip.Prefix
|
||||
UnsafeIPv4OriginAddress() netip.Prefix
|
||||
SNATAddress() netip.Prefix
|
||||
Name() string
|
||||
RoutesFor(netip.Addr) routing.Gateways
|
||||
SupportsMultiqueue() bool
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
package overlay
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
|
|
@ -22,22 +24,22 @@ func (e *NameError) Error() string {
|
|||
}
|
||||
|
||||
// TODO: We may be able to remove routines
|
||||
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
|
||||
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error)
|
||||
|
||||
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
||||
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error) {
|
||||
switch {
|
||||
case c.GetBool("tun.disabled", false):
|
||||
tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
|
||||
return tun, nil
|
||||
|
||||
default:
|
||||
return newTun(c, l, vpnNetworks, routines > 1)
|
||||
return newTun(c, l, vpnNetworks, unsafeNetworks, routines > 1)
|
||||
}
|
||||
}
|
||||
|
||||
func NewFdDeviceFromConfig(fd *int) DeviceFactory {
|
||||
return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
||||
return newTunFromFd(c, l, *fd, vpnNetworks)
|
||||
return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error) {
|
||||
return newTunFromFd(c, l, *fd, vpnNetworks, unsafeNetworks)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -129,3 +131,85 @@ func selectGateway(dest netip.Prefix, gateways []netip.Prefix) (netip.Prefix, er
|
|||
|
||||
return netip.Prefix{}, fmt.Errorf("no gateway found for %v in the list of vpn networks", dest)
|
||||
}
|
||||
|
||||
// genLinkLocal generates a random IPv4 link-local address.
|
||||
// If randomizer is nil, it uses rand.Reader to find two random bytes
|
||||
func genLinkLocal(randomizer io.Reader) netip.Prefix {
|
||||
if randomizer == nil {
|
||||
randomizer = rand.Reader
|
||||
}
|
||||
octets := []byte{169, 254, 0, 0}
|
||||
_, _ = randomizer.Read(octets[2:4])
|
||||
return coerceLinkLocal(octets)
|
||||
}
|
||||
|
||||
func coerceLinkLocal(octets []byte) netip.Prefix {
|
||||
if octets[3] == 0 {
|
||||
octets[3] = 1 //please no .0 addresses
|
||||
} else if octets[2] == 255 && octets[3] == 255 {
|
||||
octets[3] = 254 //please no broadcast addresses
|
||||
}
|
||||
out, _ := netip.AddrFromSlice(octets)
|
||||
return netip.PrefixFrom(out, 32)
|
||||
}
|
||||
|
||||
// prepareUnsafeOriginAddr provides the IPv4 address used on IPv6-only clients that need to access IPv4 unsafe routes
|
||||
func prepareUnsafeOriginAddr(d Device, l *logrus.Logger, c *config.C, routes []Route) netip.Prefix {
|
||||
if !d.Networks()[0].Addr().Is6() {
|
||||
return netip.Prefix{} //if we have an IPv4 assignment within the overlay, we don't need an unsafe origin address
|
||||
}
|
||||
|
||||
needed := false
|
||||
for _, route := range routes { //or if we have a route defined into an IPv4 range
|
||||
if route.Cidr.Addr().Is4() {
|
||||
needed = true //todo should this only apply to unsafe routes? almost certainly
|
||||
break
|
||||
}
|
||||
}
|
||||
if !needed {
|
||||
return netip.Prefix{}
|
||||
}
|
||||
|
||||
//todo better config name for sure
|
||||
if a := c.GetString("tun.unsafe_origin_address_for_4over6", ""); a != "" {
|
||||
out, err := netip.ParseAddr(a)
|
||||
if err != nil {
|
||||
l.WithField("value", a).WithError(err).Warn("failed to parse tun.unsafe_origin_address_for_4over6, will use a random value")
|
||||
} else if !out.Is4() || !out.IsLinkLocalUnicast() {
|
||||
l.WithField("value", out).Warn("tun.unsafe_origin_address_for_4over6 must be an IPv4 address")
|
||||
} else if out.IsValid() {
|
||||
return netip.PrefixFrom(out, 32)
|
||||
}
|
||||
}
|
||||
return genLinkLocal(nil)
|
||||
}
|
||||
|
||||
// prepareSnatAddr provides the address that an IPv6-only unsafe router should use to SNAT traffic before handing it to the operating system
|
||||
func prepareSnatAddr(d Device, l *logrus.Logger, c *config.C) netip.Prefix {
|
||||
if !d.Networks()[0].Addr().Is6() {
|
||||
return netip.Prefix{} //if we have an IPv4 assignment within the overlay, we don't need a snat address
|
||||
}
|
||||
|
||||
needed := false
|
||||
for _, un := range d.UnsafeNetworks() { //if we are an unsafe router for an IPv4 range
|
||||
if un.Addr().Is4() {
|
||||
needed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !needed {
|
||||
return netip.Prefix{}
|
||||
}
|
||||
|
||||
if a := c.GetString("tun.snat_address_for_4over6", ""); a != "" {
|
||||
out, err := netip.ParseAddr(a)
|
||||
if err != nil {
|
||||
l.WithField("value", a).WithError(err).Warn("failed to parse tun.snat_address_for_4over6, will use a random value")
|
||||
} else if !out.Is4() || !out.IsLinkLocalUnicast() {
|
||||
l.WithField("value", out).Warn("tun.snat_address_for_4over6 must be an IPv4 address")
|
||||
} else if out.IsValid() {
|
||||
return netip.PrefixFrom(out, 32)
|
||||
}
|
||||
}
|
||||
return genLinkLocal(nil)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,14 +19,16 @@ import (
|
|||
|
||||
type tun struct {
|
||||
io.ReadWriteCloser
|
||||
fd int
|
||||
vpnNetworks []netip.Prefix
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
l *logrus.Logger
|
||||
fd int
|
||||
vpnNetworks []netip.Prefix
|
||||
unsafeNetworks []netip.Prefix
|
||||
unsafeIPv4Origin netip.Prefix
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix) (*tun, error) {
|
||||
// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
|
||||
// Be sure not to call file.Fd() as it will set the fd to blocking mode.
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||
|
|
@ -35,6 +37,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net
|
|||
ReadWriteCloser: file,
|
||||
fd: deviceFd,
|
||||
vpnNetworks: vpnNetworks,
|
||||
unsafeNetworks: unsafeNetworks,
|
||||
l: l,
|
||||
}
|
||||
|
||||
|
|
@ -53,7 +56,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net
|
|||
return t, nil
|
||||
}
|
||||
|
||||
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
|
||||
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ []netip.Prefix, _ bool) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTun not supported in Android")
|
||||
}
|
||||
|
||||
|
|
@ -76,6 +79,8 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
t.unsafeIPv4Origin = prepareUnsafeOriginAddr(t, t.l, c, routes)
|
||||
|
||||
routeTree, err := makeRouteTree(t.l, routes, false)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -91,6 +96,18 @@ func (t *tun) Networks() []netip.Prefix {
|
|||
return t.vpnNetworks
|
||||
}
|
||||
|
||||
func (t *tun) UnsafeNetworks() []netip.Prefix {
|
||||
return t.unsafeNetworks
|
||||
}
|
||||
|
||||
func (t *tun) UnsafeIPv4OriginAddress() netip.Prefix {
|
||||
return t.unsafeIPv4Origin
|
||||
}
|
||||
|
||||
func (t *tun) SNATAddress() netip.Prefix {
|
||||
return netip.Prefix{}
|
||||
}
|
||||
|
||||
func (t *tun) Name() string {
|
||||
return "android"
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,13 +24,15 @@ import (
|
|||
|
||||
type tun struct {
|
||||
io.ReadWriteCloser
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
DefaultMTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
linkAddr *netroute.LinkAddr
|
||||
l *logrus.Logger
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
unsafeNetworks []netip.Prefix
|
||||
unsafeIPv4Origin netip.Prefix
|
||||
DefaultMTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
linkAddr *netroute.LinkAddr
|
||||
l *logrus.Logger
|
||||
|
||||
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
||||
out []byte
|
||||
|
|
@ -79,7 +81,7 @@ type ifreqAlias6 struct {
|
|||
Lifetime addrLifetime
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
name := c.GetString("tun.dev", "")
|
||||
ifIndex := -1
|
||||
if name != "" && name != "utun" {
|
||||
|
|
@ -127,6 +129,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||
ReadWriteCloser: os.NewFile(uintptr(fd), ""),
|
||||
Device: name,
|
||||
vpnNetworks: vpnNetworks,
|
||||
unsafeNetworks: unsafeNetworks,
|
||||
DefaultMTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
}
|
||||
|
|
@ -153,7 +156,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
|
|||
return
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
|
||||
}
|
||||
|
||||
|
|
@ -213,6 +216,11 @@ func (t *tun) Activate() error {
|
|||
}
|
||||
}
|
||||
}
|
||||
if t.unsafeIPv4Origin.IsValid() && t.unsafeIPv4Origin.Addr().Is4() {
|
||||
if err = t.activate4(t.unsafeIPv4Origin); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Run the interface
|
||||
ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING
|
||||
|
|
@ -314,6 +322,10 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
if initial {
|
||||
t.unsafeIPv4Origin = prepareUnsafeOriginAddr(t, t.l, c, routes)
|
||||
}
|
||||
|
||||
routeTree, err := makeRouteTree(t.l, routes, false)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -545,6 +557,18 @@ func (t *tun) Networks() []netip.Prefix {
|
|||
return t.vpnNetworks
|
||||
}
|
||||
|
||||
func (t *tun) UnsafeNetworks() []netip.Prefix {
|
||||
return t.unsafeNetworks
|
||||
}
|
||||
|
||||
func (t *tun) UnsafeIPv4OriginAddress() netip.Prefix {
|
||||
return t.unsafeIPv4Origin
|
||||
}
|
||||
|
||||
func (t *tun) SNATAddress() netip.Prefix {
|
||||
return netip.Prefix{}
|
||||
}
|
||||
|
||||
func (t *tun) Name() string {
|
||||
return t.Device
|
||||
}
|
||||
|
|
|
|||
|
|
@ -52,6 +52,17 @@ func (t *disabledTun) Networks() []netip.Prefix {
|
|||
return t.vpnNetworks
|
||||
}
|
||||
|
||||
func (*disabledTun) UnsafeNetworks() []netip.Prefix {
|
||||
return nil
|
||||
}
|
||||
func (*disabledTun) SNATAddress() netip.Prefix {
|
||||
return netip.Prefix{}
|
||||
}
|
||||
|
||||
func (*disabledTun) UnsafeIPv4OriginAddress() netip.Prefix {
|
||||
return netip.Prefix{}
|
||||
}
|
||||
|
||||
func (*disabledTun) Name() string {
|
||||
return "disabled"
|
||||
}
|
||||
|
|
|
|||
|
|
@ -86,14 +86,16 @@ type ifreqAlias6 struct {
|
|||
}
|
||||
|
||||
type tun struct {
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
MTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
linkAddr *netroute.LinkAddr
|
||||
l *logrus.Logger
|
||||
devFd int
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
unsafeNetworks []netip.Prefix
|
||||
unsafeIPv4Origin netip.Prefix
|
||||
MTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
linkAddr *netroute.LinkAddr
|
||||
l *logrus.Logger
|
||||
devFd int
|
||||
}
|
||||
|
||||
func (t *tun) Read(to []byte) (int, error) {
|
||||
|
|
@ -199,11 +201,11 @@ func (t *tun) Close() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
// Try to open existing tun device
|
||||
var fd int
|
||||
var err error
|
||||
|
|
@ -270,11 +272,12 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||
}
|
||||
|
||||
t := &tun{
|
||||
Device: deviceName,
|
||||
vpnNetworks: vpnNetworks,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
devFd: fd,
|
||||
Device: deviceName,
|
||||
vpnNetworks: vpnNetworks,
|
||||
unsafeNetworks: unsafeNetworks,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
devFd: fd,
|
||||
}
|
||||
|
||||
err = t.reload(c, true)
|
||||
|
|
@ -410,6 +413,10 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
if initial {
|
||||
t.unsafeIPv4Origin = prepareUnsafeOriginAddr(t, t.l, c, routes)
|
||||
}
|
||||
|
||||
routeTree, err := makeRouteTree(t.l, routes, false)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -446,6 +453,18 @@ func (t *tun) Networks() []netip.Prefix {
|
|||
return t.vpnNetworks
|
||||
}
|
||||
|
||||
func (t *tun) UnsafeNetworks() []netip.Prefix {
|
||||
return t.unsafeNetworks
|
||||
}
|
||||
|
||||
func (t *tun) UnsafeIPv4OriginAddress() netip.Prefix {
|
||||
return t.unsafeIPv4Origin
|
||||
}
|
||||
|
||||
func (t *tun) SNATAddress() netip.Prefix {
|
||||
return netip.Prefix{}
|
||||
}
|
||||
|
||||
func (t *tun) Name() string {
|
||||
return t.Device
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,20 +22,23 @@ import (
|
|||
|
||||
type tun struct {
|
||||
io.ReadWriteCloser
|
||||
vpnNetworks []netip.Prefix
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
l *logrus.Logger
|
||||
vpnNetworks []netip.Prefix
|
||||
unsafeNetworks []netip.Prefix
|
||||
unsafeIPv4Origin netip.Prefix
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
|
||||
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ []netip.Prefix, _ bool) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTun not supported in iOS")
|
||||
}
|
||||
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix) (*tun, error) {
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
|
||||
t := &tun{
|
||||
vpnNetworks: vpnNetworks,
|
||||
unsafeNetworks: unsafeNetworks,
|
||||
ReadWriteCloser: &tunReadCloser{f: file},
|
||||
l: l,
|
||||
}
|
||||
|
|
@ -69,6 +72,8 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
t.unsafeIPv4Origin = prepareUnsafeOriginAddr(t, t.l, c, routes)
|
||||
|
||||
routeTree, err := makeRouteTree(t.l, routes, false)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -147,6 +152,18 @@ func (t *tun) Networks() []netip.Prefix {
|
|||
return t.vpnNetworks
|
||||
}
|
||||
|
||||
func (t *tun) UnsafeNetworks() []netip.Prefix {
|
||||
return t.unsafeNetworks
|
||||
}
|
||||
|
||||
func (t *tun) UnsafeIPv4OriginAddress() netip.Prefix {
|
||||
return t.unsafeIPv4Origin
|
||||
}
|
||||
|
||||
func (t *tun) SNATAddress() netip.Prefix {
|
||||
return netip.Prefix{}
|
||||
}
|
||||
|
||||
func (t *tun) Name() string {
|
||||
return "iOS"
|
||||
}
|
||||
|
|
|
|||
|
|
@ -26,14 +26,15 @@ import (
|
|||
|
||||
type tun struct {
|
||||
io.ReadWriteCloser
|
||||
fd int
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
MaxMTU int
|
||||
DefaultMTU int
|
||||
TXQueueLen int
|
||||
deviceIndex int
|
||||
ioctlFd uintptr
|
||||
fd int
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
unsafeNetworks []netip.Prefix
|
||||
MaxMTU int
|
||||
DefaultMTU int
|
||||
TXQueueLen int
|
||||
deviceIndex int
|
||||
ioctlFd uintptr
|
||||
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
|
|
@ -46,6 +47,9 @@ type tun struct {
|
|||
routesFromSystem map[netip.Prefix]routing.Gateways
|
||||
routesFromSystemLock sync.Mutex
|
||||
|
||||
snatAddr netip.Prefix
|
||||
unsafeIPv4Origin netip.Prefix
|
||||
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
|
|
@ -53,6 +57,18 @@ func (t *tun) Networks() []netip.Prefix {
|
|||
return t.vpnNetworks
|
||||
}
|
||||
|
||||
func (t *tun) UnsafeNetworks() []netip.Prefix {
|
||||
return t.unsafeNetworks
|
||||
}
|
||||
|
||||
func (t *tun) UnsafeIPv4OriginAddress() netip.Prefix {
|
||||
return t.unsafeIPv4Origin
|
||||
}
|
||||
|
||||
func (t *tun) SNATAddress() netip.Prefix {
|
||||
return t.snatAddr
|
||||
}
|
||||
|
||||
type ifReq struct {
|
||||
Name [16]byte
|
||||
Flags uint16
|
||||
|
|
@ -71,10 +87,10 @@ type ifreqQLEN struct {
|
|||
pad [8]byte
|
||||
}
|
||||
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix) (*tun, error) {
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||
|
||||
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
||||
t, err := newTunGeneric(c, l, file, vpnNetworks, unsafeNetworks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -84,7 +100,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net
|
|||
return t, nil
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
|
||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
|
||||
|
|
@ -123,7 +139,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
|||
name := strings.Trim(string(req.Name[:]), "\x00")
|
||||
|
||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
||||
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
||||
t, err := newTunGeneric(c, l, file, vpnNetworks, unsafeNetworks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -133,11 +149,12 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
|||
return t, nil
|
||||
}
|
||||
|
||||
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix) (*tun, error) {
|
||||
t := &tun{
|
||||
ReadWriteCloser: file,
|
||||
fd: int(file.Fd()),
|
||||
vpnNetworks: vpnNetworks,
|
||||
unsafeNetworks: unsafeNetworks,
|
||||
TXQueueLen: c.GetInt("tun.tx_queue", 500),
|
||||
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
|
||||
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
|
||||
|
|
@ -170,6 +187,11 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
if initial {
|
||||
t.unsafeIPv4Origin = prepareUnsafeOriginAddr(t, t.l, c, routes) //todo MUST be different from t.snatAddr!
|
||||
t.snatAddr = prepareSnatAddr(t, t.l, c)
|
||||
}
|
||||
|
||||
routeTree, err := makeRouteTree(t.l, routes, true)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -313,6 +335,17 @@ func (t *tun) addIPs(link netlink.Link) error {
|
|||
}
|
||||
}
|
||||
|
||||
if t.unsafeIPv4Origin.IsValid() {
|
||||
newAddrs = append(newAddrs, &netlink.Addr{
|
||||
IPNet: &net.IPNet{
|
||||
IP: t.unsafeIPv4Origin.Addr().AsSlice(),
|
||||
Mask: net.CIDRMask(t.unsafeIPv4Origin.Bits(), t.unsafeIPv4Origin.Addr().BitLen()),
|
||||
},
|
||||
Label: t.unsafeIPv4Origin.Addr().Zone(),
|
||||
})
|
||||
t.l.WithField("address", t.unsafeIPv4Origin).Info("Adding origin address for IPv4 unsafe_routes")
|
||||
}
|
||||
|
||||
//add all new addresses
|
||||
for i := range newAddrs {
|
||||
//AddrReplace still adds new IPs, but if their properties change it will change them as well
|
||||
|
|
@ -400,7 +433,13 @@ func (t *tun) Activate() error {
|
|||
//set route MTU
|
||||
for i := range t.vpnNetworks {
|
||||
if err = t.setDefaultRoute(t.vpnNetworks[i]); err != nil {
|
||||
return fmt.Errorf("failed to set default route MTU: %w", err)
|
||||
return fmt.Errorf("failed to set default route MTU for %s: %w", t.vpnNetworks[i], err)
|
||||
}
|
||||
}
|
||||
|
||||
if t.unsafeIPv4Origin.IsValid() {
|
||||
if err = t.setDefaultRoute(t.unsafeIPv4Origin); err != nil {
|
||||
return fmt.Errorf("failed to set default route MTU for %s: %w", t.unsafeIPv4Origin, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -427,6 +466,23 @@ func (t *tun) setMTU() {
|
|||
}
|
||||
}
|
||||
|
||||
func (t *tun) setSnatRoute() error {
|
||||
dr := &net.IPNet{
|
||||
IP: t.snatAddr.Masked().Addr().AsSlice(),
|
||||
Mask: net.CIDRMask(t.snatAddr.Bits(), t.snatAddr.Addr().BitLen()),
|
||||
}
|
||||
|
||||
nr := netlink.Route{
|
||||
LinkIndex: t.deviceIndex,
|
||||
Dst: dr,
|
||||
Scope: unix.RT_SCOPE_LINK,
|
||||
//Protocol: unix.RTPROT_KERNEL,
|
||||
Table: unix.RT_TABLE_MAIN,
|
||||
Type: unix.RTN_UNICAST,
|
||||
}
|
||||
return netlink.RouteReplace(&nr)
|
||||
}
|
||||
|
||||
func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
|
||||
dr := &net.IPNet{
|
||||
IP: cidr.Masked().Addr().AsSlice(),
|
||||
|
|
@ -503,6 +559,13 @@ func (t *tun) addRoutes(logErrors bool) error {
|
|||
}
|
||||
}
|
||||
|
||||
if t.snatAddr.IsValid() {
|
||||
//at least for Linux, we need to set a return route for the SNATted traffic in order to satisfy the reverse-path filter,
|
||||
//and to help the kernel deliver our reply traffic to the tun device.
|
||||
//however, it is important that we do not actually /assign/ the SNAT address,
|
||||
//since link-local addresses will not be routed between interfaces without significant trickery.
|
||||
return t.setSnatRoute()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -58,23 +58,25 @@ type addrLifetime struct {
|
|||
}
|
||||
|
||||
type tun struct {
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
MTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
l *logrus.Logger
|
||||
f *os.File
|
||||
fd int
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
unsafeNetworks []netip.Prefix
|
||||
unsafeIPv4Origin netip.Prefix
|
||||
MTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
l *logrus.Logger
|
||||
f *os.File
|
||||
fd int
|
||||
}
|
||||
|
||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
// Try to open tun device
|
||||
var err error
|
||||
deviceName := c.GetString("tun.dev", "")
|
||||
|
|
@ -350,6 +352,10 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
if initial {
|
||||
t.unsafeIPv4Origin = prepareUnsafeOriginAddr(t, t.l, c, routes)
|
||||
}
|
||||
|
||||
routeTree, err := makeRouteTree(t.l, routes, false)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -386,6 +392,18 @@ func (t *tun) Networks() []netip.Prefix {
|
|||
return t.vpnNetworks
|
||||
}
|
||||
|
||||
func (t *tun) UnsafeNetworks() []netip.Prefix {
|
||||
return t.unsafeNetworks
|
||||
}
|
||||
|
||||
func (t *tun) UnsafeIPv4OriginAddress() netip.Prefix {
|
||||
return t.unsafeIPv4Origin
|
||||
}
|
||||
|
||||
func (t *tun) SNATAddress() netip.Prefix {
|
||||
return netip.Prefix{}
|
||||
}
|
||||
|
||||
func (t *tun) Name() string {
|
||||
return t.Device
|
||||
}
|
||||
|
|
|
|||
|
|
@ -49,25 +49,27 @@ type ifreq struct {
|
|||
}
|
||||
|
||||
type tun struct {
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
MTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
l *logrus.Logger
|
||||
f *os.File
|
||||
fd int
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
unsafeNetworks []netip.Prefix
|
||||
unsafeIPv4Origin netip.Prefix
|
||||
MTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
l *logrus.Logger
|
||||
f *os.File
|
||||
fd int
|
||||
// cache out buffer since we need to prepend 4 bytes for tun metadata
|
||||
out []byte
|
||||
}
|
||||
|
||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in openbsd")
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
// Try to open tun device
|
||||
var err error
|
||||
deviceName := c.GetString("tun.dev", "")
|
||||
|
|
@ -89,12 +91,13 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||
}
|
||||
|
||||
t := &tun{
|
||||
f: os.NewFile(uintptr(fd), ""),
|
||||
fd: fd,
|
||||
Device: deviceName,
|
||||
vpnNetworks: vpnNetworks,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
f: os.NewFile(uintptr(fd), ""),
|
||||
fd: fd,
|
||||
Device: deviceName,
|
||||
vpnNetworks: vpnNetworks,
|
||||
unsafeNetworks: unsafeNetworks,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
}
|
||||
|
||||
err = t.reload(c, true)
|
||||
|
|
@ -270,6 +273,10 @@ func (t *tun) reload(c *config.C, initial bool) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
if initial {
|
||||
t.unsafeIPv4Origin = prepareUnsafeOriginAddr(t, t.l, c, routes)
|
||||
}
|
||||
|
||||
routeTree, err := makeRouteTree(t.l, routes, false)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -306,6 +313,18 @@ func (t *tun) Networks() []netip.Prefix {
|
|||
return t.vpnNetworks
|
||||
}
|
||||
|
||||
func (t *tun) UnsafeNetworks() []netip.Prefix {
|
||||
return t.unsafeNetworks
|
||||
}
|
||||
|
||||
func (t *tun) UnsafeIPv4OriginAddress() netip.Prefix {
|
||||
return t.unsafeIPv4Origin
|
||||
}
|
||||
|
||||
func (t *tun) SNATAddress() netip.Prefix {
|
||||
return netip.Prefix{}
|
||||
}
|
||||
|
||||
func (t *tun) Name() string {
|
||||
return t.Device
|
||||
}
|
||||
|
|
|
|||
179
overlay/tun_snat_test.go
Normal file
179
overlay/tun_snat_test.go
Normal file
|
|
@ -0,0 +1,179 @@
|
|||
package overlay
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/routing"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// mockDevice is a minimal Device implementation for testing prepareUnsafeOriginAddr.
|
||||
type mockDevice struct {
|
||||
networks []netip.Prefix
|
||||
unsafeNetworks []netip.Prefix
|
||||
snatAddr netip.Prefix
|
||||
unsafeSnatAddr netip.Prefix
|
||||
}
|
||||
|
||||
func (d *mockDevice) Read([]byte) (int, error) { return 0, nil }
|
||||
func (d *mockDevice) Write([]byte) (int, error) { return 0, nil }
|
||||
func (d *mockDevice) Close() error { return nil }
|
||||
func (d *mockDevice) Activate() error { return nil }
|
||||
func (d *mockDevice) Networks() []netip.Prefix { return d.networks }
|
||||
func (d *mockDevice) UnsafeNetworks() []netip.Prefix { return d.unsafeNetworks }
|
||||
func (d *mockDevice) SNATAddress() netip.Prefix { return d.snatAddr }
|
||||
func (d *mockDevice) UnsafeIPv4OriginAddress() netip.Prefix { return d.unsafeSnatAddr }
|
||||
func (d *mockDevice) Name() string { return "mock" }
|
||||
func (d *mockDevice) RoutesFor(netip.Addr) routing.Gateways { return routing.Gateways{} }
|
||||
func (d *mockDevice) SupportsMultiqueue() bool { return false }
|
||||
func (d *mockDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, nil }
|
||||
|
||||
func TestPrepareSnatAddr_V4Primary_NoSnat(t *testing.T) {
|
||||
l := logrus.New()
|
||||
l.SetLevel(logrus.PanicLevel)
|
||||
c := config.NewC(l)
|
||||
|
||||
// If the device has an IPv4 primary address, no SNAT needed
|
||||
d := &mockDevice{
|
||||
networks: []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")},
|
||||
}
|
||||
result := prepareUnsafeOriginAddr(d, l, c, nil)
|
||||
assert.Equal(t, netip.Prefix{}, result, "should not assign SNAT addr when device has IPv4 primary")
|
||||
}
|
||||
|
||||
func TestPrepareSnatAddr_V6Primary_NoUnsafeOrRoutes(t *testing.T) {
|
||||
l := logrus.New()
|
||||
l.SetLevel(logrus.PanicLevel)
|
||||
c := config.NewC(l)
|
||||
|
||||
// IPv6 primary but no unsafe networks or IPv4 routes
|
||||
d := &mockDevice{
|
||||
networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")},
|
||||
}
|
||||
result := prepareUnsafeOriginAddr(d, l, c, nil)
|
||||
assert.Equal(t, netip.Prefix{}, result, "should not assign SNAT addr without IPv4 unsafe networks or routes")
|
||||
}
|
||||
|
||||
func TestPrepareSnatAddr_V6Primary_WithV4Unsafe(t *testing.T) {
|
||||
l := logrus.New()
|
||||
l.SetLevel(logrus.PanicLevel)
|
||||
c := config.NewC(l)
|
||||
|
||||
// IPv6 primary with IPv4 unsafe network -> should get SNAT addr
|
||||
d := &mockDevice{
|
||||
networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")},
|
||||
unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
|
||||
}
|
||||
result := prepareSnatAddr(d, l, c)
|
||||
require.True(t, result.IsValid(), "should assign SNAT addr")
|
||||
assert.True(t, result.Addr().Is4(), "SNAT addr should be IPv4")
|
||||
assert.True(t, result.Addr().IsLinkLocalUnicast(), "SNAT addr should be link-local")
|
||||
assert.Equal(t, 32, result.Bits(), "SNAT addr should be /32")
|
||||
|
||||
result = prepareUnsafeOriginAddr(d, l, c, nil)
|
||||
require.False(t, result.IsValid(), "no routes = no origin addr needed")
|
||||
}
|
||||
|
||||
func TestPrepareUnsafeOriginAddr_V6Primary_WithV4Route(t *testing.T) {
|
||||
l := logrus.New()
|
||||
l.SetLevel(logrus.PanicLevel)
|
||||
c := config.NewC(l)
|
||||
|
||||
// IPv6 primary with IPv4 route -> should get SNAT addr
|
||||
d := &mockDevice{
|
||||
networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")},
|
||||
}
|
||||
routes := []Route{
|
||||
{Cidr: netip.MustParsePrefix("10.0.0.0/8")},
|
||||
}
|
||||
result := prepareUnsafeOriginAddr(d, l, c, routes)
|
||||
require.True(t, result.IsValid(), "should assign SNAT addr when IPv4 route exists")
|
||||
assert.True(t, result.Addr().Is4())
|
||||
assert.True(t, result.Addr().IsLinkLocalUnicast())
|
||||
|
||||
result = prepareSnatAddr(d, l, c)
|
||||
require.False(t, result.IsValid(), "no UnsafeNetworks = no snat addr needed")
|
||||
}
|
||||
|
||||
func TestPrepareSnatAddr_V6Primary_V6UnsafeOnly(t *testing.T) {
|
||||
l := logrus.New()
|
||||
l.SetLevel(logrus.PanicLevel)
|
||||
c := config.NewC(l)
|
||||
|
||||
// IPv6 primary with only IPv6 unsafe network -> no SNAT needed
|
||||
d := &mockDevice{
|
||||
networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")},
|
||||
unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("fd01::/64")},
|
||||
}
|
||||
result := prepareUnsafeOriginAddr(d, l, c, nil)
|
||||
assert.Equal(t, netip.Prefix{}, result, "should not assign SNAT addr for IPv6-only unsafe networks")
|
||||
}
|
||||
|
||||
func TestPrepareSnatAddr_ManualAddress(t *testing.T) {
|
||||
l := logrus.New()
|
||||
l.SetLevel(logrus.PanicLevel)
|
||||
c := config.NewC(l)
|
||||
c.Settings["tun"] = map[string]any{
|
||||
"snat_address_for_4over6": "169.254.42.42",
|
||||
}
|
||||
|
||||
d := &mockDevice{
|
||||
networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")},
|
||||
unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
|
||||
}
|
||||
result := prepareSnatAddr(d, l, c)
|
||||
require.True(t, result.IsValid())
|
||||
assert.Equal(t, netip.MustParseAddr("169.254.42.42"), result.Addr())
|
||||
assert.Equal(t, 32, result.Bits())
|
||||
}
|
||||
|
||||
func TestPrepareSnatAddr_InvalidManualAddress_Fallback(t *testing.T) {
|
||||
l := logrus.New()
|
||||
l.SetLevel(logrus.PanicLevel)
|
||||
c := config.NewC(l)
|
||||
c.Settings["tun"] = map[string]any{
|
||||
"snat_address_for_4over6": "not-an-ip",
|
||||
}
|
||||
|
||||
d := &mockDevice{
|
||||
networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")},
|
||||
unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
|
||||
}
|
||||
result := prepareSnatAddr(d, l, c)
|
||||
// Should fall back to auto-assignment
|
||||
require.True(t, result.IsValid(), "should fall back to auto-assigned address")
|
||||
assert.True(t, result.Addr().Is4())
|
||||
assert.True(t, result.Addr().IsLinkLocalUnicast())
|
||||
}
|
||||
|
||||
func TestPrepareSnatAddr_AutoGenerated_Range(t *testing.T) {
|
||||
l := logrus.New()
|
||||
l.SetLevel(logrus.PanicLevel)
|
||||
c := config.NewC(l)
|
||||
|
||||
d := &mockDevice{
|
||||
networks: []netip.Prefix{netip.MustParsePrefix("fd00::1/128")},
|
||||
unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
|
||||
}
|
||||
|
||||
// Generate several addresses and verify they're all in the expected range
|
||||
for i := 0; i < 100; i++ {
|
||||
result := prepareSnatAddr(d, l, c)
|
||||
require.True(t, result.IsValid())
|
||||
addr := result.Addr()
|
||||
octets := addr.As4()
|
||||
assert.Equal(t, byte(169), octets[0], "first octet should be 169")
|
||||
assert.Equal(t, byte(254), octets[1], "second octet should be 254")
|
||||
// Should not have .0 in the last octet
|
||||
assert.NotEqual(t, byte(0), octets[3], "last octet should not be 0")
|
||||
// Should not be 169.254.255.255 (broadcast)
|
||||
if octets[2] == 255 {
|
||||
assert.NotEqual(t, byte(255), octets[3], "should not be broadcast address")
|
||||
}
|
||||
}
|
||||
}
|
||||
37
overlay/tun_test.go
Normal file
37
overlay/tun_test.go
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
package overlay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestLinkLocal(t *testing.T) {
|
||||
r := bytes.NewReader([]byte{42, 99})
|
||||
result := genLinkLocal(r)
|
||||
assert.Equal(t, netip.MustParsePrefix("169.254.42.99/32"), result, "genLinkLocal with a deterministic randomizer")
|
||||
|
||||
result = genLinkLocal(nil)
|
||||
assert.True(t, result.IsValid(), "genLinkLocal with nil randomizer should be valid")
|
||||
assert.True(t, result.Addr().IsLinkLocalUnicast(), "genLinkLocal with nil randomizer should be link-local")
|
||||
|
||||
result = coerceLinkLocal([]byte{169, 254, 100, 50})
|
||||
assert.Equal(t, netip.MustParsePrefix("169.254.100.50/32"), result, "coerceLinkLocal should pass through normal values")
|
||||
|
||||
result = coerceLinkLocal([]byte{169, 254, 0, 0})
|
||||
assert.Equal(t, netip.MustParsePrefix("169.254.0.1/32"), result, "coerceLinkLocal should bump .0 last octet to .1")
|
||||
|
||||
result = coerceLinkLocal([]byte{169, 254, 255, 255})
|
||||
assert.Equal(t, netip.MustParsePrefix("169.254.255.254/32"), result, "coerceLinkLocal should bump broadcast 255.255 to 255.254")
|
||||
|
||||
result = coerceLinkLocal([]byte{169, 254, 0, 1})
|
||||
assert.Equal(t, netip.MustParsePrefix("169.254.0.1/32"), result, "coerceLinkLocal should leave .1 last octet unchanged")
|
||||
|
||||
result = coerceLinkLocal([]byte{169, 254, 255, 254})
|
||||
assert.Equal(t, netip.MustParsePrefix("169.254.255.254/32"), result, "coerceLinkLocal should leave 255.254 unchanged")
|
||||
|
||||
result = coerceLinkLocal([]byte{169, 254, 255, 100})
|
||||
assert.Equal(t, netip.MustParsePrefix("169.254.255.100/32"), result, "coerceLinkLocal should leave 255.100 unchanged")
|
||||
}
|
||||
|
|
@ -17,18 +17,21 @@ import (
|
|||
)
|
||||
|
||||
type TestTun struct {
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
Routes []Route
|
||||
routeTree *bart.Table[routing.Gateways]
|
||||
l *logrus.Logger
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
unsafeNetworks []netip.Prefix
|
||||
snatAddr netip.Prefix
|
||||
unsafeIPv4Origin netip.Prefix
|
||||
Routes []Route
|
||||
routeTree *bart.Table[routing.Gateways]
|
||||
l *logrus.Logger
|
||||
|
||||
closed atomic.Bool
|
||||
rxPackets chan []byte // Packets to receive into nebula
|
||||
TxPackets chan []byte // Packets transmitted outside by nebula
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*TestTun, error) {
|
||||
_, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -38,18 +41,22 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||
return nil, err
|
||||
}
|
||||
|
||||
return &TestTun{
|
||||
Device: c.GetString("tun.dev", ""),
|
||||
vpnNetworks: vpnNetworks,
|
||||
Routes: routes,
|
||||
routeTree: routeTree,
|
||||
l: l,
|
||||
rxPackets: make(chan []byte, 10),
|
||||
TxPackets: make(chan []byte, 10),
|
||||
}, nil
|
||||
tt := &TestTun{
|
||||
Device: c.GetString("tun.dev", ""),
|
||||
vpnNetworks: vpnNetworks,
|
||||
unsafeNetworks: unsafeNetworks,
|
||||
Routes: routes,
|
||||
routeTree: routeTree,
|
||||
l: l,
|
||||
rxPackets: make(chan []byte, 10),
|
||||
TxPackets: make(chan []byte, 10),
|
||||
}
|
||||
tt.unsafeIPv4Origin = prepareUnsafeOriginAddr(tt, l, c, routes)
|
||||
tt.snatAddr = prepareSnatAddr(tt, tt.l, c)
|
||||
return tt, nil
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*TestTun, error) {
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*TestTun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported")
|
||||
}
|
||||
|
||||
|
|
@ -139,3 +146,15 @@ func (t *TestTun) SupportsMultiqueue() bool {
|
|||
func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
|
||||
return nil, fmt.Errorf("TODO: multiqueue not implemented")
|
||||
}
|
||||
|
||||
func (t *TestTun) UnsafeNetworks() []netip.Prefix {
|
||||
return t.unsafeNetworks
|
||||
}
|
||||
|
||||
func (t *TestTun) UnsafeIPv4OriginAddress() netip.Prefix {
|
||||
return t.unsafeIPv4Origin
|
||||
}
|
||||
|
||||
func (t *TestTun) SNATAddress() netip.Prefix {
|
||||
return t.snatAddr
|
||||
}
|
||||
|
|
|
|||
|
|
@ -28,21 +28,23 @@ import (
|
|||
const tunGUIDLabel = "Fixed Nebula Windows GUID v1"
|
||||
|
||||
type winTun struct {
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
MTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
l *logrus.Logger
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
unsafeNetworks []netip.Prefix
|
||||
unsafeIPv4Origin netip.Prefix
|
||||
MTU int
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
l *logrus.Logger
|
||||
|
||||
tun *wintun.NativeTun
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) {
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (Device, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*winTun, error) {
|
||||
err := checkWinTunExists()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can not load the wintun driver: %w", err)
|
||||
|
|
@ -55,10 +57,11 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
|||
}
|
||||
|
||||
t := &winTun{
|
||||
Device: deviceName,
|
||||
vpnNetworks: vpnNetworks,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
Device: deviceName,
|
||||
vpnNetworks: vpnNetworks,
|
||||
unsafeNetworks: unsafeNetworks,
|
||||
MTU: c.GetInt("tun.mtu", DefaultMTU),
|
||||
l: l,
|
||||
}
|
||||
|
||||
err = t.reload(c, true)
|
||||
|
|
@ -102,6 +105,10 @@ func (t *winTun) reload(c *config.C, initial bool) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
if initial {
|
||||
t.unsafeIPv4Origin = prepareUnsafeOriginAddr(t, t.l, c, routes)
|
||||
}
|
||||
|
||||
routeTree, err := makeRouteTree(t.l, routes, false)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -132,7 +139,12 @@ func (t *winTun) reload(c *config.C, initial bool) error {
|
|||
func (t *winTun) Activate() error {
|
||||
luid := winipcfg.LUID(t.tun.LUID())
|
||||
|
||||
err := luid.SetIPAddresses(t.vpnNetworks)
|
||||
prefixes := t.vpnNetworks
|
||||
if t.unsafeIPv4Origin.IsValid() {
|
||||
prefixes = append(prefixes, t.unsafeIPv4Origin)
|
||||
}
|
||||
|
||||
err := luid.SetIPAddresses(prefixes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set address: %w", err)
|
||||
}
|
||||
|
|
@ -225,6 +237,18 @@ func (t *winTun) Networks() []netip.Prefix {
|
|||
return t.vpnNetworks
|
||||
}
|
||||
|
||||
func (t *winTun) UnsafeNetworks() []netip.Prefix {
|
||||
return t.unsafeNetworks
|
||||
}
|
||||
|
||||
func (t *winTun) UnsafeIPv4OriginAddress() netip.Prefix {
|
||||
return t.unsafeIPv4Origin
|
||||
}
|
||||
|
||||
func (t *winTun) SNATAddress() netip.Prefix {
|
||||
return netip.Prefix{}
|
||||
}
|
||||
|
||||
func (t *winTun) Name() string {
|
||||
return t.Device
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import (
|
|||
"github.com/slackhq/nebula/routing"
|
||||
)
|
||||
|
||||
func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
||||
func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error) {
|
||||
return NewUserDevice(vpnNetworks)
|
||||
}
|
||||
|
||||
|
|
@ -36,6 +36,17 @@ type UserDevice struct {
|
|||
inboundWriter *io.PipeWriter
|
||||
}
|
||||
|
||||
func (d *UserDevice) UnsafeNetworks() []netip.Prefix {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *UserDevice) SNATAddress() netip.Prefix {
|
||||
return netip.Prefix{}
|
||||
}
|
||||
func (d *UserDevice) UnsafeIPv4OriginAddress() netip.Prefix {
|
||||
return netip.Prefix{}
|
||||
}
|
||||
|
||||
func (d *UserDevice) Activate() error {
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
47
pki.go
47
pki.go
|
|
@ -91,7 +91,7 @@ func (p *PKI) reload(c *config.C, initial bool) error {
|
|||
}
|
||||
|
||||
func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
|
||||
newState, err := newCertStateFromConfig(c)
|
||||
newState, err := newCertStateFromConfig(c, p.l)
|
||||
if err != nil {
|
||||
return util.NewContextualError("Could not load client cert", nil, err)
|
||||
}
|
||||
|
|
@ -102,7 +102,7 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
|
|||
if currentState.v1Cert == nil {
|
||||
//adding certs is fine, actually. Networks-in-common confirmed in newCertState().
|
||||
} else {
|
||||
// did IP in cert change? if so, don't set
|
||||
// did IP in cert change? if so, don't set. If we ever allow this, need to set p.firewallReloadNeeded
|
||||
if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
|
||||
return util.NewContextualError(
|
||||
"Networks in new cert was different from old",
|
||||
|
|
@ -158,6 +158,14 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
|
|||
}
|
||||
}
|
||||
|
||||
newUN := newState.GetDefaultCertificate().UnsafeNetworks()
|
||||
oldUN := currentState.GetDefaultCertificate().UnsafeNetworks()
|
||||
if !slices.Equal(newUN, oldUN) {
|
||||
//todo I don't love this, because other clients will see the new assignments and act on them, but we will not be able to.
|
||||
//I think we need to wire this into the firewall reload.
|
||||
p.l.WithFields(m{"previous": oldUN, "new": newUN}).Warning("UnsafeNetworks assignments differ. A restart is required in order for this to take effect.")
|
||||
}
|
||||
|
||||
// Cipher cant be hot swapped so just leave it at what it was before
|
||||
newState.cipher = currentState.cipher
|
||||
|
||||
|
|
@ -260,7 +268,7 @@ func (cs *CertState) MarshalJSON() ([]byte, error) {
|
|||
return json.Marshal(msg)
|
||||
}
|
||||
|
||||
func newCertStateFromConfig(c *config.C) (*CertState, error) {
|
||||
func newCertStateFromConfig(c *config.C, l *logrus.Logger) (*CertState, error) {
|
||||
var err error
|
||||
|
||||
privPathOrPEM := c.GetString("pki.key", "")
|
||||
|
|
@ -344,10 +352,33 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) {
|
|||
return nil, fmt.Errorf("unknown pki.initiating_version: %v", rawInitiatingVersion)
|
||||
}
|
||||
|
||||
return newCertState(initiatingVersion, v1, v2, isPkcs11, curve, rawKey)
|
||||
return newCertState(l, initiatingVersion, v1, v2, isPkcs11, curve, rawKey)
|
||||
}
|
||||
|
||||
func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) {
|
||||
func compareUnsafeNetworksAcrossCertVersions(v1, v2 cert.Certificate) error {
|
||||
if v1 == nil || v2 == nil {
|
||||
return nil //can't be a problem if we don't have one of the kinds of cert
|
||||
}
|
||||
|
||||
v4UnsafeNets := 0
|
||||
for _, n := range v2.UnsafeNetworks() {
|
||||
if n.Addr().Is6() {
|
||||
continue // V1 certs can't have IPv6 unsafe networks
|
||||
} else {
|
||||
v4UnsafeNets++
|
||||
}
|
||||
if !slices.Contains(v1.UnsafeNetworks(), n) {
|
||||
return errors.New("UnsafeNetworks mismatch")
|
||||
}
|
||||
}
|
||||
if len(v1.UnsafeNetworks()) != v4UnsafeNets {
|
||||
return errors.New("UnsafeNetworks mismatch")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func newCertState(l *logrus.Logger, dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) {
|
||||
cs := CertState{
|
||||
privateKey: privateKey,
|
||||
pkcs11Backed: pkcs11backed,
|
||||
|
|
@ -370,6 +401,12 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
|
|||
}
|
||||
|
||||
cs.initiatingVersion = dv
|
||||
|
||||
warn := compareUnsafeNetworksAcrossCertVersions(v1, v2)
|
||||
if warn != nil {
|
||||
l.WithFields(m{"UnsafeNetworksV1": v1.UnsafeNetworks(), "UnsafeNetworksV2": v2.UnsafeNetworks()}).
|
||||
Warning("the IPv4 UnsafeNetworks assigned in the V1 certificate do not match the ones in V2")
|
||||
}
|
||||
}
|
||||
|
||||
if v1 != nil {
|
||||
|
|
|
|||
91
snat.go
Normal file
91
snat.go
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
func recalcIPv4Checksum(data []byte, oldSrcIP netip.Addr, newSrcIP netip.Addr) {
|
||||
oldChecksum := binary.BigEndian.Uint16(data[10:12])
|
||||
//because of how checksums work, we can re-use this function
|
||||
checksum := calcNewTransportChecksum(oldChecksum, oldSrcIP, 0, newSrcIP, 0)
|
||||
binary.BigEndian.PutUint16(data[10:12], checksum)
|
||||
}
|
||||
|
||||
func calcNewTransportChecksum(oldChecksum uint16, oldSrcIP netip.Addr, oldSrcPort uint16, newSrcIP netip.Addr, newSrcPort uint16) uint16 {
|
||||
oldIP := binary.BigEndian.Uint32(oldSrcIP.AsSlice())
|
||||
newIP := binary.BigEndian.Uint32(newSrcIP.AsSlice())
|
||||
|
||||
// Start with inverted checksum
|
||||
checksum := uint32(^oldChecksum)
|
||||
|
||||
// Subtract old IP (as two 16-bit words)
|
||||
checksum += uint32(^uint16(oldIP >> 16))
|
||||
checksum += uint32(^uint16(oldIP & 0xFFFF))
|
||||
|
||||
// Subtract old port
|
||||
checksum += uint32(^oldSrcPort)
|
||||
|
||||
// Add new IP (as two 16-bit words)
|
||||
checksum += uint32(newIP >> 16)
|
||||
checksum += uint32(newIP & 0xFFFF)
|
||||
|
||||
// Add new port
|
||||
checksum += uint32(newSrcPort)
|
||||
|
||||
// Fold carries
|
||||
for checksum > 0xFFFF {
|
||||
checksum = (checksum & 0xFFFF) + (checksum >> 16)
|
||||
}
|
||||
|
||||
// Return ones' complement
|
||||
return ^uint16(checksum)
|
||||
}
|
||||
|
||||
func recalcV4TransportChecksum(offsetInsideHeader int, data []byte, oldSrcIP netip.AddrPort, newSrcIP netip.AddrPort) {
|
||||
ipHeaderOffset := int(data[0]&0x0F) * 4
|
||||
offset := ipHeaderOffset + offsetInsideHeader
|
||||
oldcsum := binary.BigEndian.Uint16(data[offset : offset+2])
|
||||
checksum := calcNewTransportChecksum(oldcsum, oldSrcIP.Addr(), oldSrcIP.Port(), newSrcIP.Addr(), newSrcIP.Port())
|
||||
binary.BigEndian.PutUint16(data[offset:offset+2], checksum)
|
||||
}
|
||||
|
||||
func recalcUDPv4Checksum(data []byte, oldSrcIP netip.AddrPort, newSrcIP netip.AddrPort) {
|
||||
const offsetInsideHeader = 6
|
||||
recalcV4TransportChecksum(offsetInsideHeader, data, oldSrcIP, newSrcIP)
|
||||
}
|
||||
|
||||
func recalcTCPv4Checksum(data []byte, oldSrcIP netip.AddrPort, newSrcIP netip.AddrPort) {
|
||||
const offsetInsideHeader = 16
|
||||
recalcV4TransportChecksum(offsetInsideHeader, data, oldSrcIP, newSrcIP)
|
||||
}
|
||||
|
||||
func calcNewICMPChecksum(oldChecksum uint16, oldCode uint16, newCode uint16, oldID uint16, newID uint16) uint16 {
|
||||
// Start with inverted checksum
|
||||
checksum := uint32(^oldChecksum)
|
||||
|
||||
// Subtract old stuff
|
||||
checksum += uint32(^oldCode)
|
||||
checksum += uint32(^oldID)
|
||||
|
||||
// Add new stuff
|
||||
checksum += uint32(newCode)
|
||||
checksum += uint32(newID)
|
||||
|
||||
// Fold carries
|
||||
for checksum > 0xFFFF {
|
||||
checksum = (checksum & 0xFFFF) + (checksum >> 16)
|
||||
}
|
||||
|
||||
// Return ones' complement
|
||||
return ^uint16(checksum)
|
||||
}
|
||||
|
||||
func recalcICMPv4Checksum(data []byte, oldCode uint16, newCode uint16, oldID uint16, newID uint16) {
|
||||
const offsetInsideHeader = 2
|
||||
ipHeaderOffset := int(data[0]&0x0F) * 4
|
||||
offset := ipHeaderOffset + offsetInsideHeader
|
||||
oldChecksum := binary.BigEndian.Uint16(data[offset : offset+2])
|
||||
checksum := calcNewICMPChecksum(oldChecksum, oldCode, newCode, oldID, newID)
|
||||
binary.BigEndian.PutUint16(data[offset:offset+2], checksum)
|
||||
}
|
||||
1310
snat_test.go
Normal file
1310
snat_test.go
Normal file
File diff suppressed because it is too large
Load diff
12
test/tun.go
12
test/tun.go
|
|
@ -10,6 +10,18 @@ import (
|
|||
|
||||
type NoopTun struct{}
|
||||
|
||||
func (NoopTun) UnsafeNetworks() []netip.Prefix {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (NoopTun) SNATAddress() netip.Prefix {
|
||||
return netip.Prefix{}
|
||||
}
|
||||
|
||||
func (NoopTun) UnsafeIPv4OriginAddress() netip.Prefix {
|
||||
return netip.Prefix{}
|
||||
}
|
||||
|
||||
func (NoopTun) RoutesFor(addr netip.Addr) routing.Gateways {
|
||||
return routing.Gateways{}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue