diff --git a/cert/cert_v2.go b/cert/cert_v2.go index 4648c496..87d1ec11 100644 --- a/cert/cert_v2.go +++ b/cert/cert_v2.go @@ -441,7 +441,7 @@ func (c *certificateV2) validate() error { } } else if network.Addr().Is4() { if !hasV4Networks { - return NewErrInvalidCertificateProperties("IPv4 unsafe networks require an IPv4 address assignment: %s", network) + //return NewErrInvalidCertificateProperties("IPv4 unsafe networks require an IPv4 address assignment: %s", network) } } } diff --git a/firewall.go b/firewall.go index 45dc0691..457210a6 100644 --- a/firewall.go +++ b/firewall.go @@ -2,6 +2,7 @@ package nebula import ( "crypto/sha256" + "encoding/binary" "encoding/hex" "errors" "fmt" @@ -26,6 +27,19 @@ type FirewallInterface interface { AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr string, caName string, caSha string) error } +type snatInfo struct { + //Src is the source IP+port to write into unsafe-route-bound packet + Src netip.AddrPort + //SrcVpnIp is the overlay IP associated with this flow. It's needed to associate reply traffic so we can get it back to the right host. + SrcVpnIp netip.Addr + //SnatPort is the port to rewrite into an overlay-bound packet + SnatPort uint16 +} + +func (s *snatInfo) Valid() bool { + return s.Src.IsValid() +} + type conn struct { Expires time.Time // Time when this conntrack entry will expire @@ -34,6 +48,9 @@ type conn struct { // fields pack for free after the uint32 above incoming bool rulesVersion uint16 + + //for SNAT support + snat snatInfo } // TODO: need conntrack max tracked connections handling @@ -66,6 +83,7 @@ type Firewall struct { defaultLocalCIDRAny bool incomingMetrics firewallMetrics outgoingMetrics firewallMetrics + snatAddr netip.Addr l *logrus.Logger } @@ -149,12 +167,14 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D tmax = defaultTimeout } + hasV4Networks := false routableNetworks := new(bart.Lite) var assignedNetworks []netip.Prefix for _, network := range c.Networks() { nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen()) routableNetworks.Insert(nprefix) assignedNetworks = append(assignedNetworks, network) + hasV4Networks = hasV4Networks || network.Addr().Is4() } hasUnsafeNetworks := false @@ -163,6 +183,11 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D hasUnsafeNetworks = true } + snatAddr := netip.Addr{} + if hasUnsafeNetworks && !hasV4Networks { + snatAddr = netip.MustParseAddr("169.254.55.96") //todo this needs to come from the config, or perhaps the tun + } + return &Firewall{ Conntrack: &FirewallConntrack{ Conns: make(map[firewall.Packet]*conn), @@ -176,6 +201,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D routableNetworks: routableNetworks, assignedNetworks: assignedNetworks, hasUnsafeNetworks: hasUnsafeNetworks, + snatAddr: snatAddr, l: l, incomingMetrics: firewallMetrics{ @@ -401,22 +427,131 @@ var ErrInvalidRemoteIP = errors.New("remote address is not in remote certificate var ErrInvalidLocalIP = errors.New("local address is not in list of handled local addresses") var ErrNoMatchingRule = errors.New("no matching rule in firewall table") +func (f *Firewall) unSnat(data []byte, fp *firewall.Packet, c *conn) netip.Addr { + if c == nil { + //unfortunately this needs to lock. Surely there's a better way, but I need to make this flow at all first. + c = f.peek(*fp) + } + if c == nil { + return netip.Addr{} + } + if !c.snat.Valid() { + return netip.Addr{} + } + + oldIP := netip.AddrPortFrom(f.snatAddr, fp.RemotePort) + + //change dst IP + copy(data[16:], c.snat.Src.Addr().AsSlice()) + recalcIPv4Checksum(data, oldIP.Addr(), c.snat.Src.Addr()) + ipHeaderLen := int(data[0]&0x0F) * 4 + //dst port is at offset 2 + dstport := ipHeaderLen + 2 + + switch fp.Protocol { + case firewall.ProtoICMP: + binary.BigEndian.PutUint16(data[ipHeaderLen+4:ipHeaderLen+6], c.snat.Src.Port()) + icmpCode := uint16(data[ipHeaderLen+1]) //todo not snatting on this yet (but Linux would) + recalcICMPv4Checksum(data, icmpCode, icmpCode, c.snat.SnatPort, c.snat.Src.Port()) + case firewall.ProtoUDP: + binary.BigEndian.PutUint16(data[dstport:dstport+2], c.snat.Src.Port()) + recalcUDPv4Checksum(data, oldIP, c.snat.Src) + case firewall.ProtoTCP: + binary.BigEndian.PutUint16(data[dstport:dstport+2], c.snat.Src.Port()) + recalcTCPv4Checksum(data, oldIP, c.snat.Src) + } + return c.snat.SrcVpnIp +} + +func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo *HostInfo) { + if c.snat.Valid() { + //old flow + fp.RemoteAddr = f.snatAddr + fp.RemotePort = c.snat.SnatPort + } else if hostinfo.vpnAddrs[0].Is6() { + //we got a new flow + c.snat.Src = netip.AddrPortFrom(fp.RemoteAddr, fp.RemotePort) + c.snat.SrcVpnIp = hostinfo.vpnAddrs[0] + + fp.RemoteAddr = f.snatAddr + //find a new port to use, if needed + for { + existingFlow := f.peek(*fp) //locking and unlocking for each peek is slow, but simple for now + if existingFlow == nil { + break //yay, we can use this port + } + //increment and retry. There's probably better strategies out there + fp.RemotePort++ + if fp.RemotePort < 0x7ff { + fp.RemotePort += 0x7ff // keep it ephemeral for now + } //todo if we're totally out of ports this loops forever. Probably not good. + } + c.snat.SnatPort = fp.RemotePort + } else { + f.l.WithFields(logrus.Fields{ + "fp": *fp, + "conn": *c, + "hostinfo": hostinfo, + }).Error("this packet cannot be snatted") + return + } + + newIP := netip.AddrPortFrom(f.snatAddr, c.snat.SnatPort) + //change src IP + copy(data[12:], f.snatAddr.AsSlice()) + recalcIPv4Checksum(data, c.snat.Src.Addr(), newIP.Addr()) + ipHeaderLen := int(data[0]&0x0F) * 4 + + switch fp.Protocol { + case firewall.ProtoICMP: + binary.BigEndian.PutUint16(data[ipHeaderLen+4:ipHeaderLen+6], c.snat.SnatPort) + icmpCode := uint16(data[ipHeaderLen+1]) //todo not snatting on this yet (but Linux would) + recalcICMPv4Checksum(data, icmpCode, icmpCode, c.snat.Src.Port(), c.snat.SnatPort) + case firewall.ProtoUDP: + //src port is at offset 0 + binary.BigEndian.PutUint16(data[ipHeaderLen:ipHeaderLen+2], c.snat.SnatPort) + recalcUDPv4Checksum(data, c.snat.Src, newIP) + case firewall.ProtoTCP: + //src port is at offset 0 + binary.BigEndian.PutUint16(data[ipHeaderLen:ipHeaderLen+2], c.snat.SnatPort) + recalcTCPv4Checksum(data, c.snat.Src, newIP) + } +} + // Drop returns an error if the packet should be dropped, explaining why. It // returns nil if the packet should not be dropped. -func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error { +func (f *Firewall) Drop(fp firewall.Packet, pkt []byte, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error { + specialSnatMode := f.hasUnsafeNetworks && fp.IsIPv4() && h.HasOnlyV6Addresses() //todo I wish I only set this once somehow + + table := f.OutRules + if incoming { + table = f.InRules + } + // Check if we spoke to this tuple, if we did then allow this packet - if f.inConns(fp, h, caPool, localCache) { - return nil + // Check the cache first, iff not snatting + if localCache != nil && !specialSnatMode { + if _, ok := localCache[fp]; ok { + return nil //packet matched the cache, we're not snatting, we can return early! + } + } + c := f.inConns(fp, h, caPool, localCache) + if c != nil { + //can't return yet, need to snat maybe + goto snat } // Make sure remote address matches nebula certificate, and determine how to treat it if h.networks == nil { // Simple case: Certificate has one address and no unsafe networks if h.vpnAddrs[0] != fp.RemoteAddr { - f.metrics(incoming).droppedRemoteAddr.Inc(1) - return ErrInvalidRemoteIP - } + if !specialSnatMode { + f.metrics(incoming).droppedRemoteAddr.Inc(1) + return ErrInvalidRemoteIP //todo! + } //else we're in special snat mode, and we need to apply more checks below + } //else? all good, fall through } else { + //todo check for srcsnortaddr here too? nwType, ok := h.networks.Lookup(fp.RemoteAddr) if !ok { f.metrics(incoming).droppedRemoteAddr.Inc(1) @@ -426,9 +561,12 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * case NetworkTypeVPN: break // nothing special case NetworkTypeVPNPeer: + //todo we might need a specialSnatMode case in here to handle routers with v4 addresses when we don't also have a v4 address? f.metrics(incoming).droppedRemoteAddr.Inc(1) return ErrPeerRejected // reject for now, one day this may have different FW rules case NetworkTypeUnsafe: + //intentionally excluding f.hasUnsafeNetworks -- this is what lets routers talk back to us with our unsafe traffic! + specialSnatMode = fp.IsIPv4() && h.HasOnlyV6Addresses() && f.assignedNetworks[0].Addr().Is6() break // nothing special, one day this may have different FW rules default: f.metrics(incoming).droppedRemoteAddr.Inc(1) @@ -437,16 +575,12 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * } // Make sure we are supposed to be handling this local ip address - if !f.routableNetworks.Contains(fp.LocalAddr) { + //todo I'm not sure I trust this heuristic + if !specialSnatMode && !f.routableNetworks.Contains(fp.LocalAddr) { f.metrics(incoming).droppedLocalAddr.Inc(1) return ErrInvalidLocalIP } - table := f.OutRules - if incoming { - table = f.InRules - } - // We now know which firewall table to check against if !table.match(fp, incoming, h.ConnectionState.peerCert, caPool) { f.metrics(incoming).droppedNoRule.Inc(1) @@ -454,7 +588,16 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * } // We always want to conntrack since it is a faster operation - f.addConn(fp, incoming) + c = f.addConn(fp, incoming) + +snat: + if incoming { + if specialSnatMode { + //todo do not snat if you are not a router for the destination -- for now, just if you're not a router + f.applySnat(pkt, &fp, c, h) + f.dupeConn(fp, c) //track the snatted flow with the same expiration as the unsnatted version + } + } //outgoing snat is handled before this function is called (for now!) return nil } @@ -483,12 +626,23 @@ func (f *Firewall) EmitStats() { metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV())) } -func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool { - if localCache != nil { - if _, ok := localCache[fp]; ok { - return true - } +func (f *Firewall) peek(fp firewall.Packet) *conn { + conntrack := f.Conntrack + conntrack.Lock() + + // Purge every time we test + ep, has := conntrack.TimerWheel.Purge() + if has { + f.evict(ep) } + + c := conntrack.Conns[fp] + + conntrack.Unlock() + return c +} + +func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) *conn { conntrack := f.Conntrack conntrack.Lock() @@ -500,9 +654,22 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, c, ok := conntrack.Conns[fp] + if !ok && fp.Protocol == firewall.ProtoICMP { + //todo this seems like it will also bite me + oldRemote := fp.RemotePort + oldLocal := fp.LocalPort + fp.RemotePort = 0 + fp.LocalPort = 0 + c, ok = conntrack.Conns[fp] + if ok { + fp.RemotePort = oldRemote + fp.LocalPort = oldLocal + } + } + if !ok { conntrack.Unlock() - return false + return nil } if c.rulesVersion != f.rulesVersion { @@ -525,7 +692,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, } delete(conntrack.Conns, fp) conntrack.Unlock() - return false + return nil } if f.l.Level >= logrus.DebugLevel { @@ -555,12 +722,11 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache[fp] = struct{}{} } - return true + return c } -func (f *Firewall) addConn(fp firewall.Packet, incoming bool) { +func (f *Firewall) packetTimeout(fp firewall.Packet) time.Duration { var timeout time.Duration - c := &conn{} switch fp.Protocol { case firewall.ProtoTCP: @@ -570,7 +736,25 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) { default: timeout = f.DefaultTimeout } + return timeout +} +func (f *Firewall) dupeConn(fp firewall.Packet, c *conn) { + conntrack := f.Conntrack + conntrack.Lock() + if _, ok := conntrack.Conns[fp]; !ok { + conntrack.TimerWheel.Advance(time.Now()) + conntrack.TimerWheel.Add(fp, f.packetTimeout(fp)) + } + + conntrack.Conns[fp] = c + conntrack.Unlock() +} + +func (f *Firewall) addConn(fp firewall.Packet, incoming bool) *conn { + c := &conn{} + + timeout := f.packetTimeout(fp) conntrack := f.Conntrack conntrack.Lock() if _, ok := conntrack.Conns[fp]; !ok { @@ -584,7 +768,18 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) { c.rulesVersion = f.rulesVersion c.Expires = time.Now().Add(timeout) conntrack.Conns[fp] = c + + //todo will this bite me somehow? + if fp.Protocol == firewall.ProtoICMP { + //not required for ICMPv6 because we don't decode or SNAT it + //create a duplicate conntrack entry with all the port information zeroed? + fp.RemotePort = 0 + fp.LocalPort = 0 + conntrack.Conns[fp] = c + } + conntrack.Unlock() + return c } // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel @@ -662,6 +857,13 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCer var port int32 + if p.Protocol == firewall.ProtoICMP { + // port numbers are re-used for connection tracking and SNAT, + // but we don't want to actually filter on them for ICMP + // ICMP6 is omitted because we don't attempt to parse code/identifier/etc out of ICMP6 + return fp[firewall.PortAny].match(p, c, caPool) + } + if p.Fragment { port = firewall.PortFragment } else if incoming { diff --git a/firewall/packet.go b/firewall/packet.go index 40c7fc5d..ce00129c 100644 --- a/firewall/packet.go +++ b/firewall/packet.go @@ -22,12 +22,20 @@ const ( type Packet struct { LocalAddr netip.Addr RemoteAddr netip.Addr - LocalPort uint16 + // LocalPort is the destination port for incoming traffic, or the source port for outgoing. + // For ICMP, it's "code". //todo also store "type?" would need to decode replies, which sucks + LocalPort uint16 + // RemotePort is the source port for incoming traffic, or the destination port for outgoing. + // For ICMP, it's the "identifier". This is only used for connection tracking, actual firewall rules will not filter on ICMP identifier RemotePort uint16 Protocol uint8 Fragment bool } +func (fp *Packet) IsIPv4() bool { + return fp.LocalAddr.Is4() && fp.RemoteAddr.Is4() +} + func (fp *Packet) Copy() *Packet { return &Packet{ LocalAddr: fp.LocalAddr, diff --git a/firewall_test.go b/firewall_test.go index 1df62a81..77e7cc69 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -212,44 +212,44 @@ func TestFirewall_Drop(t *testing.T) { cp := cert.NewCAPool() // Drop outbound - assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil)) + assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, nil, false, &h, cp, nil)) // Allow inbound resetConntrack(fw) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) // Allow outbound because conntrack - require.NoError(t, fw.Drop(p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil)) // test remote mismatch oldRemote := p.RemoteAddr p.RemoteAddr = netip.MustParseAddr("1.2.3.10") - assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP) + assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrInvalidRemoteIP) p.RemoteAddr = oldRemote // ensure signer doesn't get in the way of group checks fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) - assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) - assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) } func TestFirewall_DropV6(t *testing.T) { @@ -291,44 +291,44 @@ func TestFirewall_DropV6(t *testing.T) { cp := cert.NewCAPool() // Drop outbound - assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil)) + assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, nil, false, &h, cp, nil)) // Allow inbound resetConntrack(fw) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) // Allow outbound because conntrack - require.NoError(t, fw.Drop(p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil)) // test remote mismatch oldRemote := p.RemoteAddr p.RemoteAddr = netip.MustParseAddr("fd12::56") - assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP) + assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrInvalidRemoteIP) p.RemoteAddr = oldRemote // ensure signer doesn't get in the way of group checks fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) - assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) - assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) } func BenchmarkFirewallTable_match(b *testing.B) { @@ -536,10 +536,10 @@ func TestFirewall_Drop2(t *testing.T) { cp := cert.NewCAPool() // h1/c1 lacks the proper groups - require.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule) + require.ErrorIs(t, fw.Drop(p, nil, true, &h1, cp, nil), ErrNoMatchingRule) // c has the proper groups resetConntrack(fw) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) } func TestFirewall_Drop3(t *testing.T) { @@ -617,18 +617,18 @@ func TestFirewall_Drop3(t *testing.T) { cp := cert.NewCAPool() // c1 should pass because host match - require.NoError(t, fw.Drop(p, true, &h1, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h1, cp, nil)) // c2 should pass because ca sha match resetConntrack(fw) - require.NoError(t, fw.Drop(p, true, &h2, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h2, cp, nil)) // c3 should fail because no match resetConntrack(fw) - assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p, nil, true, &h3, cp, nil), ErrNoMatchingRule) // Test a remote address match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", "")) - require.NoError(t, fw.Drop(p, true, &h1, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h1, cp, nil)) } func TestFirewall_Drop3V6(t *testing.T) { @@ -666,7 +666,7 @@ func TestFirewall_Drop3V6(t *testing.T) { fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) cp := cert.NewCAPool() require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", "")) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) } func TestFirewall_DropConntrackReload(t *testing.T) { @@ -708,12 +708,12 @@ func TestFirewall_DropConntrackReload(t *testing.T) { cp := cert.NewCAPool() // Drop outbound - assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) // Allow outbound because conntrack - require.NoError(t, fw.Drop(p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil)) oldFw := fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) @@ -722,7 +722,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { fw.rulesVersion = oldFw.rulesVersion + 1 // Allow outbound because conntrack and new rules allow port 10 - require.NoError(t, fw.Drop(p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil)) oldFw = fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) @@ -731,7 +731,160 @@ func TestFirewall_DropConntrackReload(t *testing.T) { fw.rulesVersion = oldFw.rulesVersion + 1 // Drop outbound because conntrack doesn't match new ruleset - assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrNoMatchingRule) +} + +func TestFirewall_ICMPPortBehavior(t *testing.T) { + l := test.NewLogger() + ob := &bytes.Buffer{} + l.SetOutput(ob) + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) + + network := netip.MustParsePrefix("1.2.3.4/24") + + c := cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "host1", + networks: []netip.Prefix{network}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + }, + InvertedGroups: map[string]struct{}{"default-group": {}}, + } + h := HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &c, + }, + vpnAddrs: []netip.Addr{network.Addr()}, + } + h.buildNetworks(myVpnNetworksTable, c.Certificate) + + cp := cert.NewCAPool() + + templ := firewall.Packet{ + LocalAddr: netip.MustParseAddr("1.2.3.4"), + RemoteAddr: netip.MustParseAddr("1.2.3.4"), + Protocol: firewall.ProtoICMP, + Fragment: false, + } + + t.Run("ICMP allowed", func(t *testing.T) { + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 0, 0, []string{"any"}, "", "", "", "", "")) + t.Run("zero ports", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 0 + p.RemotePort = 0 + // Drop outbound + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil)) + //now also allow outbound + require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil)) + }) + + t.Run("nonzero ports", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 0xabcd + p.RemotePort = 0x1234 + // Drop outbound + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil)) + //now also allow outbound + require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil)) + }) + }) + + t.Run("Any proto, some ports allowed", func(t *testing.T) { + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 80, 444, []string{"any"}, "", "", "", "", "")) + t.Run("zero ports, still blocked", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 0 + p.RemotePort = 0 + // Drop outbound + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule) + //now also allow outbound + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) + }) + + t.Run("nonzero ports, still blocked", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 0xabcd + p.RemotePort = 0x1234 + // Drop outbound + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule) + //now also allow outbound + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) + }) + + t.Run("nonzero, matching ports, still blocked", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 80 + p.RemotePort = 80 + // Drop outbound + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule) + //now also allow outbound + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) + }) + }) + t.Run("Any proto, any port", func(t *testing.T) { + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) + t.Run("zero ports, allowed", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 0 + p.RemotePort = 0 + // Drop outbound + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil)) + //now also allow outbound + require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil)) + }) + + t.Run("nonzero ports, allowed", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 0xabcd + p.RemotePort = 0x1234 + // Drop outbound + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil)) + //now also allow outbound + require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil)) + }) + + t.Run("nonzero ports, allowed", func(t *testing.T) { + p := templ.Copy() + p.LocalPort = 0xabcd + p.RemotePort = 0x1234 + // Drop outbound + assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule) + // Allow inbound + resetConntrack(fw) + require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil)) + //now also allow outbound with a different ID + p.RemotePort++ + require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil)) + }) + }) + } func TestFirewall_DropIPSpoofing(t *testing.T) { @@ -777,7 +930,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) { Protocol: firewall.ProtoUDP, Fragment: false, } - assert.Equal(t, fw.Drop(p, true, &h1, cp, nil), ErrInvalidRemoteIP) + assert.Equal(t, fw.Drop(p, nil, true, &h1, cp, nil), ErrInvalidRemoteIP) } func BenchmarkLookup(b *testing.B) { @@ -1184,7 +1337,7 @@ func (c *testcase) Test(t *testing.T, fw *Firewall) { t.Helper() cp := cert.NewCAPool() resetConntrack(fw) - err := fw.Drop(c.p, true, c.h, cp, nil) + err := fw.Drop(c.p, nil, true, c.h, cp, nil) if c.err == nil { require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr) } else { diff --git a/hostmap.go b/hostmap.go index 7e2939e0..a451bd98 100644 --- a/hostmap.go +++ b/hostmap.go @@ -224,6 +224,7 @@ const ( NetworkTypeVPNPeer // NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks() NetworkTypeUnsafe + //todo consider NetworkTypeLinkLocal or NetworkTypeSNAT ) type HostInfo struct { @@ -277,6 +278,15 @@ type HostInfo struct { lastUsed time.Time } +func (i *HostInfo) HasOnlyV6Addresses() bool { + for _, vpnIp := range i.vpnAddrs { + if !vpnIp.Is6() { + return false + } + } + return true +} + type ViaSender struct { UdpAddr netip.AddrPort relayHI *HostInfo // relayHI is the host info object of the relay diff --git a/inside.go b/inside.go index 0d53f952..a8a0f86d 100644 --- a/inside.go +++ b/inside.go @@ -48,9 +48,24 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet return } - hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) { - hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) - }) + var hostinfo *HostInfo + var ready bool + + snatMode := f.firewall.snatAddr.IsValid() && fwPacket.RemoteAddr == f.firewall.snatAddr + + if snatMode { + //todo unsnat happens here, would be nice to not + destVpnAddr := f.firewall.unSnat(packet, fwPacket, nil) //todo bail if we can't unsnat? + if destVpnAddr.IsValid() { + hostinfo, ready = f.getOrHandshakeNoRouting(destVpnAddr, func(hh *HandshakeHostInfo) { + hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) + }) + } //otherwise, hostinfo will be nil + } else { //if we didn't need to unsnat + hostinfo, ready = f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) { + hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) + }) + } if hostinfo == nil { f.rejectInside(packet, out, q) @@ -66,10 +81,9 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet return } - dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) + dropReason := f.firewall.Drop(*fwPacket, packet, false, hostinfo, f.pki.GetCAPool(), localCache) if dropReason == nil { f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q) - } else { f.rejectInside(packet, out, q) if f.l.Level >= logrus.DebugLevel { @@ -218,7 +232,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp } // check if packet is in outbound fw rules - dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil) + dropReason := f.firewall.Drop(*fp, nil, false, hostinfo, f.pki.GetCAPool(), nil) if dropReason != nil { if f.l.Level >= logrus.DebugLevel { f.l.WithField("fwPacket", fp). diff --git a/main.go b/main.go index 17aaa548..15b3e677 100644 --- a/main.go +++ b/main.go @@ -135,7 +135,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg deviceFactory = overlay.NewDeviceFromConfig } - tun, err = deviceFactory(c, l, pki.getCertState().myVpnNetworks, routines) + cs := pki.getCertState() + tun, err = deviceFactory(c, l, cs.myVpnNetworks, cs.GetDefaultCertificate().UnsafeNetworks(), routines) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err) } diff --git a/outside.go b/outside.go index 172c3e83..786c7c7e 100644 --- a/outside.go +++ b/outside.go @@ -329,7 +329,7 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error { switch proto { case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader: fp.Protocol = uint8(proto) - fp.RemotePort = 0 + fp.RemotePort = 0 //we don't attempt to parse ICMPv6 because we don't SNAT it fp.LocalPort = 0 fp.Fragment = false return nil @@ -434,22 +434,28 @@ func parseV4(data []byte, incoming bool, fp *firewall.Packet) error { if incoming { fp.RemoteAddr, _ = netip.AddrFromSlice(data[12:16]) fp.LocalAddr, _ = netip.AddrFromSlice(data[16:20]) - if fp.Fragment || fp.Protocol == firewall.ProtoICMP { + if fp.Fragment { fp.RemotePort = 0 fp.LocalPort = 0 + } else if fp.Protocol == firewall.ProtoICMP { //note that orientation doesn't matter on ICMP + fp.RemotePort = binary.BigEndian.Uint16(data[ihl+4 : ihl+6]) //identifier + fp.LocalPort = uint16(data[ihl+1]) //code } else { - fp.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2]) - fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) + fp.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2]) //src port + fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) //dst port } } else { fp.LocalAddr, _ = netip.AddrFromSlice(data[12:16]) fp.RemoteAddr, _ = netip.AddrFromSlice(data[16:20]) - if fp.Fragment || fp.Protocol == firewall.ProtoICMP { + if fp.Fragment { fp.RemotePort = 0 fp.LocalPort = 0 + } else if fp.Protocol == firewall.ProtoICMP { //note that orientation doesn't matter on ICMP + fp.RemotePort = binary.BigEndian.Uint16(data[ihl+4 : ihl+6]) //identifier + fp.LocalPort = uint16(data[ihl+1]) //code } else { - fp.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2]) - fp.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) + fp.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2]) //src port + fp.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) //dst port } } @@ -494,7 +500,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out return false } - dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache) + dropReason := f.firewall.Drop(*fwPacket, out, true, hostinfo, f.pki.GetCAPool(), localCache) if dropReason != nil { // NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore // This gives us a buffer to build the reject packet in diff --git a/overlay/tun.go b/overlay/tun.go index 3a61d186..35adbcf3 100644 --- a/overlay/tun.go +++ b/overlay/tun.go @@ -13,22 +13,22 @@ import ( const DefaultMTU = 1300 // TODO: We may be able to remove routines -type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) +type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error) -func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { +func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error) { switch { case c.GetBool("tun.disabled", false): tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) return tun, nil default: - return newTun(c, l, vpnNetworks, routines > 1) + return newTun(c, l, vpnNetworks, unsafeNetworks, routines > 1) } } func NewFdDeviceFromConfig(fd *int) DeviceFactory { - return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { - return newTunFromFd(c, l, *fd, vpnNetworks) + return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error) { + return newTunFromFd(c, l, *fd, vpnNetworks, unsafeNetworks) } } diff --git a/overlay/tun_android.go b/overlay/tun_android.go index eddef882..f091772a 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -26,7 +26,7 @@ type tun struct { l *logrus.Logger } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix, _ []netip.Prefix) (*tun, error) { // XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly. // Be sure not to call file.Fd() as it will set the fd to blocking mode. file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") @@ -53,7 +53,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net return t, nil } -func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ []netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in Android") } diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 128c2001..c9c3927e 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -79,7 +79,7 @@ type ifreqAlias6 struct { Lifetime addrLifetime } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) { name := c.GetString("tun.dev", "") ifIndex := -1 if name != "" && name != "utun" { @@ -153,7 +153,7 @@ func (t *tun) deviceBytes() (o [16]byte) { return } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in Darwin") } diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 8d292263..939e0569 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -199,11 +199,11 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open existing tun device var fd int var err error diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index 0ce01df8..85466d1e 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -28,11 +28,11 @@ type tun struct { l *logrus.Logger } -func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ []netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in iOS") } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix, _ []netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/tun") t := &tun{ vpnNetworks: vpnNetworks, diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index ea666f86..c3e1183f 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -26,14 +26,15 @@ import ( type tun struct { io.ReadWriteCloser - fd int - Device string - vpnNetworks []netip.Prefix - MaxMTU int - DefaultMTU int - TXQueueLen int - deviceIndex int - ioctlFd uintptr + fd int + Device string + vpnNetworks []netip.Prefix + unsafeNetworks []netip.Prefix + MaxMTU int + DefaultMTU int + TXQueueLen int + deviceIndex int + ioctlFd uintptr Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] @@ -71,10 +72,10 @@ type ifreqQLEN struct { pad [8]byte } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") - t, err := newTunGeneric(c, l, file, vpnNetworks) + t, err := newTunGeneric(c, l, file, vpnNetworks, unsafeNetworks) if err != nil { return nil, err } @@ -84,7 +85,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net return t, nil } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, multiqueue bool) (*tun, error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { // If /dev/net/tun doesn't exist, try to create it (will happen in docker) @@ -119,7 +120,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu name := strings.Trim(string(req.Name[:]), "\x00") file := os.NewFile(uintptr(fd), "/dev/net/tun") - t, err := newTunGeneric(c, l, file, vpnNetworks) + t, err := newTunGeneric(c, l, file, vpnNetworks, unsafeNetworks) if err != nil { return nil, err } @@ -129,11 +130,12 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu return t, nil } -func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) { +func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix) (*tun, error) { t := &tun{ ReadWriteCloser: file, fd: int(file.Fd()), vpnNetworks: vpnNetworks, + unsafeNetworks: unsafeNetworks, TXQueueLen: c.GetInt("tun.tx_queue", 500), useSystemRoutes: c.GetBool("tun.use_system_route_table", false), useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0), @@ -423,6 +425,27 @@ func (t *tun) setMTU() { } } +func (t *tun) setSnatRoute() error { + snataddr := netip.MustParsePrefix("169.254.55.96/32") //todo get this from elsewhere? Or maybe we should pick it, and feed it back out to the firewall? + dr := &net.IPNet{ + IP: snataddr.Masked().Addr().AsSlice(), + Mask: net.CIDRMask(snataddr.Bits(), snataddr.Addr().BitLen()), + } + + nr := netlink.Route{ + LinkIndex: t.deviceIndex, + Dst: dr, + //todo do we need these other options? + //MTU: t.DefaultMTU, + //AdvMSS: t.advMSS(Route{}), + Scope: unix.RT_SCOPE_LINK, + //Protocol: unix.RTPROT_KERNEL, + Table: unix.RT_TABLE_MAIN, + Type: unix.RTN_UNICAST, + } + return netlink.RouteReplace(&nr) +} + func (t *tun) setDefaultRoute(cidr netip.Prefix) error { dr := &net.IPNet{ IP: cidr.Masked().Addr().AsSlice(), @@ -499,6 +522,18 @@ func (t *tun) addRoutes(logErrors bool) error { } } + onlyV6Addresses := false + for _, n := range t.vpnNetworks { + if n.Addr().Is6() { + onlyV6Addresses = true + break + } + } + + if len(t.unsafeNetworks) != 0 && onlyV6Addresses { + return t.setSnatRoute() + } + return nil } diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index 2986c895..39336108 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -70,11 +70,11 @@ type tun struct { var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in NetBSD") } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open tun device var err error deviceName := c.GetString("tun.dev", "") diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index 9209b795..701d97dd 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -63,11 +63,11 @@ type tun struct { var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in openbsd") } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open tun device var err error deviceName := c.GetString("tun.dev", "") diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index 3477de3d..145eccb9 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -28,7 +28,7 @@ type TestTun struct { TxPackets chan []byte // Packets transmitted outside by nebula } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*TestTun, error) { _, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true) if err != nil { return nil, err @@ -49,7 +49,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( }, nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*TestTun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*TestTun, error) { return nil, fmt.Errorf("newTunFromFd not supported") } diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index b4d78b66..18ee533f 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -38,11 +38,11 @@ type winTun struct { tun *wintun.NativeTun } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (Device, error) { return nil, fmt.Errorf("newTunFromFd not supported in Windows") } -func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*winTun, error) { err := checkWinTunExists() if err != nil { return nil, fmt.Errorf("can not load the wintun driver: %w", err) diff --git a/overlay/user.go b/overlay/user.go index 1f92d4e9..52fa0df7 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -9,7 +9,7 @@ import ( "github.com/slackhq/nebula/routing" ) -func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { +func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error) { return NewUserDevice(vpnNetworks) } diff --git a/routing/balance.go b/routing/balance.go index 6f524970..22459113 100644 --- a/routing/balance.go +++ b/routing/balance.go @@ -12,7 +12,12 @@ import ( // - https://github.com/skeeto/hash-prospector // [16 21f0aaad 15 d35a2d97 15] = 0.10760229515479501 func hashPacket(p *firewall.Packet) int { - x := (uint32(p.LocalPort) << 16) | uint32(p.RemotePort) + var x uint32 + if p.Protocol == firewall.ProtoICMP { + x = 0 //Don't attempt to use ICMP's code/id/etc to balance + } else { + x = (uint32(p.LocalPort) << 16) | uint32(p.RemotePort) + } x ^= x >> 16 x *= 0x21f0aaad x ^= x >> 15 diff --git a/snat.go b/snat.go new file mode 100644 index 00000000..3164b641 --- /dev/null +++ b/snat.go @@ -0,0 +1,91 @@ +package nebula + +import ( + "encoding/binary" + "net/netip" +) + +func recalcIPv4Checksum(data []byte, oldSrcIP netip.Addr, newSrcIP netip.Addr) { + oldChecksum := binary.BigEndian.Uint16(data[10:12]) + //because of how checksums work, we can re-use this function + checksum := calcNewTransportChecksum(oldChecksum, oldSrcIP, 0, newSrcIP, 0) + binary.BigEndian.PutUint16(data[10:12], checksum) +} + +func calcNewTransportChecksum(oldChecksum uint16, oldSrcIP netip.Addr, oldSrcPort uint16, newSrcIP netip.Addr, newSrcPort uint16) uint16 { + oldIP := binary.BigEndian.Uint32(oldSrcIP.AsSlice()) + newIP := binary.BigEndian.Uint32(newSrcIP.AsSlice()) + + // Start with inverted checksum + checksum := uint32(^oldChecksum) + + // Subtract old IP (as two 16-bit words) + checksum += uint32(^uint16(oldIP >> 16)) + checksum += uint32(^uint16(oldIP & 0xFFFF)) + + // Subtract old port + checksum += uint32(^oldSrcPort) + + // Add new IP (as two 16-bit words) + checksum += uint32(newIP >> 16) + checksum += uint32(newIP & 0xFFFF) + + // Add new port + checksum += uint32(newSrcPort) + + // Fold carries + for checksum > 0xFFFF { + checksum = (checksum & 0xFFFF) + (checksum >> 16) + } + + // Return ones' complement + return ^uint16(checksum) +} + +func recalcV4TransportChecksum(offsetInsideHeader int, data []byte, oldSrcIP netip.AddrPort, newSrcIP netip.AddrPort) { + ipHeaderOffset := int(data[0]&0x0F) * 4 + offset := ipHeaderOffset + offsetInsideHeader + oldcsum := binary.BigEndian.Uint16(data[offset : offset+2]) + checksum := calcNewTransportChecksum(oldcsum, oldSrcIP.Addr(), oldSrcIP.Port(), newSrcIP.Addr(), newSrcIP.Port()) + binary.BigEndian.PutUint16(data[offset:offset+2], checksum) +} + +func recalcUDPv4Checksum(data []byte, oldSrcIP netip.AddrPort, newSrcIP netip.AddrPort) { + const offsetInsideHeader = 6 + recalcV4TransportChecksum(offsetInsideHeader, data, oldSrcIP, newSrcIP) +} + +func recalcTCPv4Checksum(data []byte, oldSrcIP netip.AddrPort, newSrcIP netip.AddrPort) { + const offsetInsideHeader = 16 + recalcV4TransportChecksum(offsetInsideHeader, data, oldSrcIP, newSrcIP) +} + +func calcNewICMPChecksum(oldChecksum uint16, oldCode uint16, newCode uint16, oldID uint16, newID uint16) uint16 { + // Start with inverted checksum + checksum := uint32(^oldChecksum) + + // Subtract old stuff + checksum += uint32(^oldCode) + checksum += uint32(^oldID) + + // Add new stuff + checksum += uint32(newCode) + checksum += uint32(newID) + + // Fold carries + for checksum > 0xFFFF { + checksum = (checksum & 0xFFFF) + (checksum >> 16) + } + + // Return ones' complement + return ^uint16(checksum) +} + +func recalcICMPv4Checksum(data []byte, oldCode uint16, newCode uint16, oldID uint16, newID uint16) { + const offsetInsideHeader = 2 + ipHeaderOffset := int(data[0]&0x0F) * 4 + offset := ipHeaderOffset + offsetInsideHeader + oldChecksum := binary.BigEndian.Uint16(data[offset : offset+2]) + checksum := calcNewICMPChecksum(oldChecksum, oldCode, newCode, oldID, newID) + binary.BigEndian.PutUint16(data[offset:offset+2], checksum) +}