Compare commits

...

12 commits

Author SHA1 Message Date
Nate Brown
aaf3fd2355 v1.10 changelog 2025-12-04 00:09:39 -06:00
Nate Brown
fe6b016d60 Pull in v1.9.5-v1.9.7 CHANGELOG 2025-12-03 20:38:29 -06:00
Nate Brown
56067afca2
Stab at better logging when a relay is being used (#1533)
Some checks are pending
gofmt / Run gofmt (push) Waiting to run
smoke-extra / Run extra smoke tests (push) Waiting to run
smoke / Run multi node smoke test (push) Waiting to run
Build and test / Build all and test on ubuntu-linux (push) Waiting to run
Build and test / Build and test on linux with boringcrypto (push) Waiting to run
Build and test / Build and test on linux with pkcs11 (push) Waiting to run
Build and test / Build and test on macos-latest (push) Waiting to run
Build and test / Build and test on windows-latest (push) Waiting to run
2025-12-03 17:48:29 -06:00
Nate Brown
64f202fa17
Make 0.0.0.0/0 and ::/0 not mean any address family, add any for that (#1538)
Some checks failed
gofmt / Run gofmt (push) Has been cancelled
smoke-extra / Run extra smoke tests (push) Has been cancelled
smoke / Run multi node smoke test (push) Has been cancelled
Build and test / Build all and test on ubuntu-linux (push) Has been cancelled
Build and test / Build and test on linux with boringcrypto (push) Has been cancelled
Build and test / Build and test on linux with pkcs11 (push) Has been cancelled
Build and test / Build and test on macos-latest (push) Has been cancelled
Build and test / Build and test on windows-latest (push) Has been cancelled
2025-11-21 13:46:36 -06:00
Jack Doan
6d7cf611c9
improve nebula-cert sign version auto-select (#1535)
Some checks are pending
gofmt / Run gofmt (push) Waiting to run
smoke-extra / Run extra smoke tests (push) Waiting to run
smoke / Run multi node smoke test (push) Waiting to run
Build and test / Build all and test on ubuntu-linux (push) Waiting to run
Build and test / Build and test on linux with boringcrypto (push) Waiting to run
Build and test / Build and test on linux with pkcs11 (push) Waiting to run
Build and test / Build and test on macos-latest (push) Waiting to run
Build and test / Build and test on windows-latest (push) Waiting to run
2025-11-20 13:27:27 -06:00
Nate Brown
83ae8077f5
No need to clear counter 0 (#1537) 2025-11-20 13:22:58 -06:00
Bryan Lee
12cf348c80
feat: support via gateway for v6 multihop for v4 routes (#1521)
Some checks are pending
gofmt / Run gofmt (push) Waiting to run
smoke-extra / Run extra smoke tests (push) Waiting to run
smoke / Run multi node smoke test (push) Waiting to run
Build and test / Build all and test on ubuntu-linux (push) Waiting to run
Build and test / Build and test on linux with boringcrypto (push) Waiting to run
Build and test / Build and test on linux with pkcs11 (push) Waiting to run
Build and test / Build and test on macos-latest (push) Waiting to run
Build and test / Build and test on windows-latest (push) Waiting to run
Co-authored-by: Nate Brown <nbrown.us@gmail.com>
2025-11-19 22:21:03 -06:00
dependabot[bot]
a5ee928990
Bump golang.org/x/crypto in the golang-x-dependencies group (#1536)
Some checks are pending
gofmt / Run gofmt (push) Waiting to run
smoke-extra / Run extra smoke tests (push) Waiting to run
smoke / Run multi node smoke test (push) Waiting to run
Build and test / Build all and test on ubuntu-linux (push) Waiting to run
Build and test / Build and test on linux with boringcrypto (push) Waiting to run
Build and test / Build and test on linux with pkcs11 (push) Waiting to run
Build and test / Build and test on macos-latest (push) Waiting to run
Build and test / Build and test on windows-latest (push) Waiting to run
Bumps the golang-x-dependencies group with 1 update: [golang.org/x/crypto](https://github.com/golang/crypto).


Updates `golang.org/x/crypto` from 0.44.0 to 0.45.0
- [Commits](https://github.com/golang/crypto/compare/v0.44.0...v0.45.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-version: 0.45.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-11-19 15:48:53 -06:00
Nate Brown
7aff313a17
Relax the restriction on routines from the config (#1531) 2025-11-19 13:10:11 -06:00
Jack Doan
297767b2e3
warn user if they configure a firewall rule that will allow way more traffic than you might expect (#1513)
Some checks are pending
gofmt / Run gofmt (push) Waiting to run
smoke-extra / Run extra smoke tests (push) Waiting to run
smoke / Run multi node smoke test (push) Waiting to run
Build and test / Build all and test on ubuntu-linux (push) Waiting to run
Build and test / Build and test on linux with boringcrypto (push) Waiting to run
Build and test / Build and test on linux with pkcs11 (push) Waiting to run
Build and test / Build and test on macos-latest (push) Waiting to run
Build and test / Build and test on windows-latest (push) Waiting to run
* warn user if they accidentally configure a firewall rule that will allow way more traffic than you might expect

* add groups-contains-any warning
2025-11-19 11:06:34 -06:00
Nate Brown
99faab505c
Fix a potential bug with udp ipv4 only on darwin (#1532) 2025-11-19 09:56:58 -06:00
dependabot[bot]
584c2668b3
Bump golang.org/x/net in the golang-x-dependencies group (#1530)
Some checks are pending
gofmt / Run gofmt (push) Waiting to run
smoke-extra / Run extra smoke tests (push) Waiting to run
smoke / Run multi node smoke test (push) Waiting to run
Build and test / Build all and test on ubuntu-linux (push) Waiting to run
Build and test / Build and test on linux with boringcrypto (push) Waiting to run
Build and test / Build and test on linux with pkcs11 (push) Waiting to run
Build and test / Build and test on macos-latest (push) Waiting to run
Build and test / Build and test on windows-latest (push) Waiting to run
Bumps the golang-x-dependencies group with 1 update: [golang.org/x/net](https://github.com/golang/net).


Updates `golang.org/x/net` from 0.46.0 to 0.47.0
- [Commits](https://github.com/golang/net/compare/v0.46.0...v0.47.0)

---
updated-dependencies:
- dependency-name: golang.org/x/net
  dependency-version: 0.47.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: golang-x-dependencies
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-11-18 21:25:03 -06:00
36 changed files with 690 additions and 316 deletions

View file

@ -7,12 +7,84 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
## [1.10.0] - 2025-12-04
### Added
- Support for ipv6 and multiple ipv4/6 addresses in the overlay.
A new v2 ASN.1 based certificate format.
Certificates now have a unified interface for external implementations.
(#1212, #1216, #1345, #1359, #1381, #1419, #1464, #1466, #1451, #1476, #1467, #1481, #1399, #1488, #1492, #1495, #1468, #1521, #1535, #1538)
**TODO: External documentation link!**
- Add the ability to mark packets on linux to better target nebula packets in iptables/nftables. (#1331)
- Add ECMP support for `unsafe_routes`. (#1332)
- PKCS11 support for P256 keys when built with `pkcs11` tag (#1153, #1482)
### Changed
- Improve logging when a relay is in use on an inbound packet. (#1533)
- Avoid fatal errors if `rountines` is > 1 on systems that don't support more than 1 routine. (#1531)
- Log a warning if a firewall rule contains an `any` that negates a more restrictive filter. (#1513)
- Accept encrypted CA passphrase from an environment variable. (#1421)
- Allow handshaking with any trusted remote. (#1509)
- Log only the count of blocklisted certificate fingerprints instead of the entire list. (#1525)
- Don't fatal when the ssh server is unable to be configured successfully. (#1520)
- Update to build against go v1.25. (#1483)
- `default_local_cidr_any` now defaults to false, meaning that any firewall rule
intended to target an `unsafe_routes` entry must explicitly declare it via the
`local_cidr` field. This is almost always the intended behavior. This flag is
deprecated and will be removed in a future release.
deprecated and will be removed in a future release. (#1373)
- Allow projects using `nebula` as a library with userspace networking to configure the `logger` and build version. (#1239)
- Upgrade to `yaml.v3`. (#1148, #1371, #1438, #1478)
### Fixed
- Fix a potential bug with udp ipv4 only on darwin. (#1532)
- Improve lost packet statistics. (#1441, #1537)
- Honor `remote_allow_list` in hole punch response. (#1186)
- Fix a panic when `tun.use_system_route_table` is `true` and a route lacks a destination. (#1437)
- Fix an issue when `tun.use_system_route_table: true` could result in heavy CPU utilization when many thousands of routes
are present. (#1326)
- Fix tests for 32 bit machines. (#1394)
- Fix a possible 32bit integer underflow in config handling. (#1353)
- Fix moving a udp address from one vpn address to another in the `static_host_map`
which could cause rapid re-handshaking with an incorrect remote. (#1259)
- Improve smoke tests in environments where the docker network is not the default. (#1347)
## [1.9.7] - 2025-10-10
### Security
- Fix an issue where Nebula could incorrectly accept and process a packet from an erroneous source IP when the sender's
certificate is configured with unsafe_routes (cert v1/v2) or multiple IPs (cert v2). (#1494)
### Changed
- Disable sending `recv_error` messages when a packet is received outside the allowable counter window. (#1459)
- Improve error messages and remove some unnecessary fatal conditions in the Windows and generic udp listener. (#1453)
## [1.9.6] - 2025-7-15
### Added
- Support dropping inactive tunnels. This is disabled by default in this release but can be enabled with `tunnels.drop_inactive`. See example config for more details. (#1413)
### Fixed
- Fix Darwin freeze due to presence of some Network Extensions (#1426)
- Ensure the same relay tunnel is always used when multiple relay tunnels are present (#1422)
- Fix Windows freeze due to ICMP error handling (#1412)
- Fix relay migration panic (#1403)
## [1.9.5] - 2024-12-05
### Added
- Gracefully ignore v2 certificates. (#1282)
### Fixed
- Fix relays that refuse to re-establish after one of the remote tunnel pairs breaks. (#1277)
## [1.9.4] - 2024-09-09
@ -671,7 +743,11 @@ created.)
- Initial public release.
[Unreleased]: https://github.com/slackhq/nebula/compare/v1.9.4...HEAD
[Unreleased]: https://github.com/slackhq/nebula/compare/v1.10.0...HEAD
[1.10.0]: https://github.com/slackhq/nebula/releases/tag/v1.10.0
[1.9.7]: https://github.com/slackhq/nebula/releases/tag/v1.9.7
[1.9.6]: https://github.com/slackhq/nebula/releases/tag/v1.9.6
[1.9.5]: https://github.com/slackhq/nebula/releases/tag/v1.9.5
[1.9.4]: https://github.com/slackhq/nebula/releases/tag/v1.9.4
[1.9.3]: https://github.com/slackhq/nebula/releases/tag/v1.9.3
[1.9.2]: https://github.com/slackhq/nebula/releases/tag/v1.9.2

View file

@ -43,7 +43,7 @@ type signFlags struct {
func newSignFlags() *signFlags {
sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)}
sf.set.Usage = func() {}
sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use, the default is to create both v1 and v2 certificates.")
sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use. The default is to match the version of the signing CA")
sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key")
sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert")
sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname")
@ -167,6 +167,10 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
return fmt.Errorf("ca certificate is expired")
}
if version == 0 {
version = caCert.Version()
}
// if no duration is given, expire one second before the root expires
if *sf.duration <= 0 {
*sf.duration = time.Until(caCert.NotAfter()) - time.Second*1
@ -279,21 +283,19 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
notBefore := time.Now()
notAfter := notBefore.Add(*sf.duration)
if version == 0 || version == cert.Version1 {
// Make sure we at least have an ip
switch version {
case cert.Version1:
// Make sure we have only one ipv4 address
if len(v4Networks) != 1 {
return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address")
}
if version == cert.Version1 {
// If we are asked to mint a v1 certificate only then we cant just ignore any v6 addresses
if len(v6Networks) > 0 {
return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4")
return newHelpErrorf("invalid -networks definition: v1 certificates can only contain ipv4 addresses")
}
if len(v6UnsafeNetworks) > 0 {
return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4")
}
return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only contain ipv4 addresses")
}
t := &cert.TBSCertificate{
@ -323,9 +325,8 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
}
crts = append(crts, nc)
}
if version == 0 || version == cert.Version2 {
case cert.Version2:
t := &cert.TBSCertificate{
Version: cert.Version2,
Name: *sf.name,
@ -353,6 +354,9 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
}
crts = append(crts, nc)
default:
// this should be unreachable
return fmt.Errorf("invalid version: %d", version)
}
if !isP11 && *sf.inPubPath == "" {

View file

@ -55,7 +55,7 @@ func Test_signHelp(t *testing.T) {
" -unsafe-networks string\n"+
" \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+
" -version uint\n"+
" \tOptional: version of the certificate format to use, the default is to create both v1 and v2 certificates.\n",
" \tOptional: version of the certificate format to use. The default is to match the version of the signing CA\n",
ob.String(),
)
}
@ -204,7 +204,7 @@ func Test_signCert(t *testing.T) {
ob.Reset()
eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4")
assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only contain ipv4 addresses")
assert.Empty(t, ob.String())
assert.Empty(t, eb.String())

View file

@ -50,11 +50,6 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i
}
static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}
b := NewBits(ReplayWindow)
// Clear out bit 0, we never transmit it, and we don't want it showing as packet loss
b.Update(l, 0)
hs, err := noise.NewHandshakeState(noise.Config{
CipherSuite: ncs,
Random: rand.Reader,
@ -74,7 +69,7 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i
ci := &ConnectionState{
H: hs,
initiator: initiator,
window: b,
window: NewBits(ReplayWindow),
myCert: crt,
}
// always start the counter from 2, as packet 1 and packet 2 are handshake packets.

View file

@ -25,11 +25,12 @@ import (
func BenchmarkHotPath(b *testing.B) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
// Put their info in our lighthouse
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Start the servers
myControl.Start()
@ -38,6 +39,9 @@ func BenchmarkHotPath(b *testing.B) {
r := router.NewR(b, myControl, theirControl)
r.CancelFlowLogs()
assertTunnel(b, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
b.ResetTimer()
for n := 0; n < b.N; n++ {
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
_ = r.RouteForAllUntilTxTun(theirControl)
@ -47,6 +51,39 @@ func BenchmarkHotPath(b *testing.B) {
theirControl.Stop()
}
func BenchmarkHotPathRelay(b *testing.B) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
// Teach my how to get to the relay and that their can be reached via the relay
myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
// Build a router so we don't have to reason who gets which packet
r := router.NewR(b, myControl, relayControl, theirControl)
r.CancelFlowLogs()
// Start the servers
myControl.Start()
relayControl.Start()
theirControl.Start()
assertTunnel(b, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
b.ResetTimer()
for n := 0; n < b.N; n++ {
myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
_ = r.RouteForAllUntilTxTun(theirControl)
}
myControl.Stop()
theirControl.Stop()
relayControl.Stop()
}
func TestGoodHandshake(t *testing.T) {
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)

View file

@ -292,7 +292,7 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
}
}
func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
func assertTunnel(t testing.TB, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
// Send a packet from them to me
controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B"))
bPacket := r.RouteForAllUntilTxTun(controlA)
@ -325,7 +325,7 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpn
assert.Equal(t, hBinA.RemoteIndex, hAinB.LocalIndex, "Host B remote index does not match host A local index")
}
func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
func assertUdpPacket(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
if toIp.Is6() {
assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort)
} else {
@ -333,7 +333,7 @@ func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr,
}
}
func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
func assertUdpPacket6(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy)
v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
assert.NotNil(t, v6, "No ipv6 data found")
@ -352,7 +352,7 @@ func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr,
assert.Equal(t, expected, data.Payload(), "Data was incorrect")
}
func assertUdpPacket4(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
func assertUdpPacket4(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
assert.NotNil(t, v4, "No ipv4 data found")

View file

@ -383,8 +383,9 @@ firewall:
# host: `any` or a literal hostname, ie `test-host`
# group: `any` or a literal group name, ie `default-group`
# groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass
# cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6.
# local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This can be used to filter destinations when using unsafe_routes.
# cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. `any` means any ip family and address.
# local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. `any` means any ip family and address.
# This can be used to filter destinations when using unsafe_routes.
# By default, this is set to only the VPN (overlay) networks assigned via the certificate networks field unless `default_local_cidr_any` is set to true.
# If there are unsafe_routes present in this config file, `local_cidr` should be set appropriately for the intended us case.
# ca_name: An issuing CA name

View file

@ -8,6 +8,7 @@ import (
"hash/fnv"
"net/netip"
"reflect"
"slices"
"strconv"
"strings"
"sync"
@ -22,7 +23,7 @@ import (
)
type FirewallInterface interface {
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, addr, localAddr netip.Prefix, caName string, caSha string) error
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr string, caName string, caSha string) error
}
type conn struct {
@ -247,22 +248,11 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
}
// AddRule properly creates the in memory rule structure for a firewall table.
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
// https://github.com/golang/go/issues/14131
sIp := ""
if ip.IsValid() {
sIp = ip.String()
}
lIp := ""
if localIp.IsValid() {
lIp = localIp.String()
}
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error {
// We need this rule string because we generate a hash. Removing this will break firewall reload.
ruleString := fmt.Sprintf(
"incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s",
incoming, proto, startPort, endPort, groups, host, sIp, lIp, caName, caSha,
incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha,
)
f.rules += ruleString + "\n"
@ -270,7 +260,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
if !incoming {
direction = "outgoing"
}
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "localIp": lIp, "caName": caName, "caSha": caSha}).
f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}).
Info("Firewall rule added")
var (
@ -297,7 +287,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
return fmt.Errorf("unknown protocol %v", proto)
}
return fp.addRule(f, startPort, endPort, groups, host, ip, localIp, caName, caSha)
return fp.addRule(f, startPort, endPort, groups, host, cidr, localCidr, caName, caSha)
}
// GetRuleHash returns a hash representation of all inbound and outbound rules
@ -337,7 +327,6 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
}
for i, t := range rs {
var groups []string
r, err := convertRule(l, t, table, i)
if err != nil {
return fmt.Errorf("%s rule #%v; %s", table, i, err)
@ -347,23 +336,10 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
return fmt.Errorf("%s rule #%v; only one of port or code should be provided", table, i)
}
if r.Host == "" && len(r.Groups) == 0 && r.Group == "" && r.Cidr == "" && r.LocalCidr == "" && r.CAName == "" && r.CASha == "" {
if r.Host == "" && len(r.Groups) == 0 && r.Cidr == "" && r.LocalCidr == "" && r.CAName == "" && r.CASha == "" {
return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided", table, i)
}
if len(r.Groups) > 0 {
groups = r.Groups
}
if r.Group != "" {
// Check if we have both groups and group provided in the rule config
if len(groups) > 0 {
return fmt.Errorf("%s rule #%v; only one of group or groups should be defined, both provided", table, i)
}
groups = []string{r.Group}
}
var sPort, errPort string
if r.Code != "" {
errPort = "code"
@ -392,23 +368,25 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
}
var cidr netip.Prefix
if r.Cidr != "" {
cidr, err = netip.ParsePrefix(r.Cidr)
if r.Cidr != "" && r.Cidr != "any" {
_, err = netip.ParsePrefix(r.Cidr)
if err != nil {
return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err)
}
}
var localCidr netip.Prefix
if r.LocalCidr != "" {
localCidr, err = netip.ParsePrefix(r.LocalCidr)
if r.LocalCidr != "" && r.LocalCidr != "any" {
_, err = netip.ParsePrefix(r.LocalCidr)
if err != nil {
return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err)
}
}
err = fw.AddRule(inbound, proto, startPort, endPort, groups, r.Host, cidr, localCidr, r.CAName, r.CASha)
if warning := r.sanity(); warning != nil {
l.Warnf("%s rule #%v; %s", table, i, warning)
}
err = fw.AddRule(inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha)
if err != nil {
return fmt.Errorf("%s rule #%v; `%s`", table, i, err)
}
@ -655,7 +633,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedC
return false
}
func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error {
if startPort > endPort {
return fmt.Errorf("start port was lower than end port")
}
@ -668,7 +646,7 @@ func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, grou
}
}
if err := fp[i].addRule(f, groups, host, ip, localIp, caName, caSha); err != nil {
if err := fp[i].addRule(f, groups, host, cidr, localCidr, caName, caSha); err != nil {
return err
}
}
@ -699,7 +677,7 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCer
return fp[firewall.PortAny].match(p, c, caPool)
}
func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp netip.Prefix, caName, caSha string) error {
func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, cidr, localCidr, caName, caSha string) error {
fr := func() *FirewallRule {
return &FirewallRule{
Hosts: make(map[string]*firewallLocalCIDR),
@ -713,14 +691,14 @@ func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, loc
fc.Any = fr()
}
return fc.Any.addRule(f, groups, host, ip, localIp)
return fc.Any.addRule(f, groups, host, cidr, localCidr)
}
if caSha != "" {
if _, ok := fc.CAShas[caSha]; !ok {
fc.CAShas[caSha] = fr()
}
err := fc.CAShas[caSha].addRule(f, groups, host, ip, localIp)
err := fc.CAShas[caSha].addRule(f, groups, host, cidr, localCidr)
if err != nil {
return err
}
@ -730,7 +708,7 @@ func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, loc
if _, ok := fc.CANames[caName]; !ok {
fc.CANames[caName] = fr()
}
err := fc.CANames[caName].addRule(f, groups, host, ip, localIp)
err := fc.CANames[caName].addRule(f, groups, host, cidr, localCidr)
if err != nil {
return err
}
@ -762,24 +740,24 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.CachedCertificate, caPool
return fc.CANames[s.Certificate.Name()].match(p, c)
}
func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error {
func (fr *FirewallRule) addRule(f *Firewall, groups []string, host, cidr, localCidr string) error {
flc := func() *firewallLocalCIDR {
return &firewallLocalCIDR{
LocalCIDR: new(bart.Lite),
}
}
if fr.isAny(groups, host, ip) {
if fr.isAny(groups, host, cidr) {
if fr.Any == nil {
fr.Any = flc()
}
return fr.Any.addRule(f, localCIDR)
return fr.Any.addRule(f, localCidr)
}
if len(groups) > 0 {
nlc := flc()
err := nlc.addRule(f, localCIDR)
err := nlc.addRule(f, localCidr)
if err != nil {
return err
}
@ -795,30 +773,34 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, l
if nlc == nil {
nlc = flc()
}
err := nlc.addRule(f, localCIDR)
err := nlc.addRule(f, localCidr)
if err != nil {
return err
}
fr.Hosts[host] = nlc
}
if ip.IsValid() {
nlc, _ := fr.CIDR.Get(ip)
if nlc == nil {
nlc = flc()
}
err := nlc.addRule(f, localCIDR)
if cidr != "" {
c, err := netip.ParsePrefix(cidr)
if err != nil {
return err
}
fr.CIDR.Insert(ip, nlc)
nlc, _ := fr.CIDR.Get(c)
if nlc == nil {
nlc = flc()
}
err = nlc.addRule(f, localCidr)
if err != nil {
return err
}
fr.CIDR.Insert(c, nlc)
}
return nil
}
func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) bool {
if len(groups) == 0 && host == "" && !ip.IsValid() {
func (fr *FirewallRule) isAny(groups []string, host string, cidr string) bool {
if len(groups) == 0 && host == "" && cidr == "" {
return true
}
@ -832,7 +814,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) boo
return true
}
if ip.IsValid() && ip.Bits() == 0 {
if cidr == "any" {
return true
}
@ -884,8 +866,13 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool
return false
}
func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
if !localIp.IsValid() {
func (flc *firewallLocalCIDR) addRule(f *Firewall, localCidr string) error {
if localCidr == "any" {
flc.Any = true
return nil
}
if localCidr == "" {
if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny {
flc.Any = true
return nil
@ -896,12 +883,13 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
}
return nil
} else if localIp.Bits() == 0 {
flc.Any = true
return nil
}
flc.LocalCIDR.Insert(localIp)
c, err := netip.ParsePrefix(localCidr)
if err != nil {
return err
}
flc.LocalCIDR.Insert(c)
return nil
}
@ -922,7 +910,6 @@ type rule struct {
Code string
Proto string
Host string
Group string
Groups []string
Cidr string
LocalCidr string
@ -964,7 +951,8 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
l.Warnf("%s rule #%v; group was an array with a single value, converting to simple value", table, i)
m["group"] = v[0]
}
r.Group = toString("group", m)
singleGroup := toString("group", m)
if rg, ok := m["groups"]; ok {
switch reflect.TypeOf(rg).Kind() {
@ -981,9 +969,60 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) {
}
}
//flatten group vs groups
if singleGroup != "" {
// Check if we have both groups and group provided in the rule config
if len(r.Groups) > 0 {
return r, fmt.Errorf("only one of group or groups should be defined, both provided")
}
r.Groups = []string{singleGroup}
}
return r, nil
}
// sanity returns an error if the rule would be evaluated in a way that would short-circuit a configured check on a wildcard value
// rules are evaluated as "port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) AND local_cidr"
func (r *rule) sanity() error {
//port, proto, local_cidr are AND, no need to check here
//ca_sha and ca_name don't have a wildcard value, no need to check here
groupsEmpty := len(r.Groups) == 0
hostEmpty := r.Host == ""
cidrEmpty := r.Cidr == ""
if (groupsEmpty && hostEmpty && cidrEmpty) == true {
return nil //no content!
}
groupsHasAny := slices.Contains(r.Groups, "any")
if groupsHasAny && len(r.Groups) > 1 {
return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the other groups specified", r.Groups)
}
if r.Host == "any" {
if !groupsEmpty {
return fmt.Errorf("groups specified as %s, but host=any will match any host, regardless of groups", r.Groups)
}
if !cidrEmpty {
return fmt.Errorf("cidr specified as %s, but host=any will match any host, regardless of cidr", r.Cidr)
}
}
if groupsHasAny {
if !hostEmpty && r.Host != "any" {
return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the specified host %s", r.Groups, r.Host)
}
if !cidrEmpty {
return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the specified cidr %s", r.Groups, r.Cidr)
}
}
//todo alert on cidr-any
return nil
}
func parsePort(s string) (startPort, endPort int32, err error) {
if s == "any" {
startPort = firewall.PortAny

View file

@ -73,78 +73,106 @@ func TestFirewall_AddRule(t *testing.T) {
ti6, err := netip.ParsePrefix("fd12::34/128")
require.NoError(t, err)
require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", "", "", "", ""))
// An empty rule is any
assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", ""))
assert.Nil(t, fw.InRules.UDP[1].Any.Any)
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", ""))
assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6.String(), "", "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti.String(), "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti6, "", ""))
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti6.String(), "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "ca-name", ""))
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha"))
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "ca-sha"))
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", "", "", "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
anyIp, err := netip.ParsePrefix("0.0.0.0/0")
require.NoError(t, err)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp.String(), "", "", ""))
assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any)
table, ok := fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1"))
assert.True(t, table.Any)
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9"))
assert.False(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
anyIp6, err := netip.ParsePrefix("::/0")
require.NoError(t, err)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6.String(), "", "", ""))
assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any)
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9"))
assert.True(t, table.Any)
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1"))
assert.False(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "any", "", "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp.String(), "", ""))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp6.String(), "", ""))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", "any", "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
// Test error conditions
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", "", "", "", ""))
require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", "", "", "", ""))
}
func TestFirewall_Drop(t *testing.T) {
@ -180,7 +208,7 @@ func TestFirewall_Drop(t *testing.T) {
h.buildNetworks(myVpnNetworksTable, &c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
cp := cert.NewCAPool()
// Drop outbound
@ -199,28 +227,28 @@ func TestFirewall_Drop(t *testing.T) {
// ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
}
@ -259,7 +287,7 @@ func TestFirewall_DropV6(t *testing.T) {
h.buildNetworks(myVpnNetworksTable, &c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
cp := cert.NewCAPool()
// Drop outbound
@ -278,28 +306,28 @@ func TestFirewall_DropV6(t *testing.T) {
// ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
}
@ -310,12 +338,12 @@ func BenchmarkFirewallTable_match(b *testing.B) {
}
pfix := netip.MustParsePrefix("172.1.1.1/32")
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix.String(), "", "", "")
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", "", pfix.String(), "", "")
pfix6 := netip.MustParsePrefix("fd11::11/128")
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6, netip.Prefix{}, "", "")
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix6, "", "")
_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6.String(), "", "", "")
_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", "", pfix6.String(), "", "")
cp := cert.NewCAPool()
b.Run("fail on proto", func(b *testing.B) {
@ -504,7 +532,7 @@ func TestFirewall_Drop2(t *testing.T) {
h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", "", "", "", ""))
cp := cert.NewCAPool()
// h1/c1 lacks the proper groups
@ -584,8 +612,8 @@ func TestFirewall_Drop3(t *testing.T) {
h3.buildNetworks(myVpnNetworksTable, c3.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", "", "", "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "signer-sha"))
cp := cert.NewCAPool()
// c1 should pass because host match
@ -599,7 +627,7 @@ func TestFirewall_Drop3(t *testing.T) {
// Test a remote address match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", ""))
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
}
@ -637,7 +665,7 @@ func TestFirewall_Drop3V6(t *testing.T) {
// Test a remote address match
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
cp := cert.NewCAPool()
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("fd12::34/120"), netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
}
@ -676,7 +704,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
h.buildNetworks(myVpnNetworksTable, c.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
cp := cert.NewCAPool()
// Drop outbound
@ -689,7 +717,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
oldFw := fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", "", "", "", ""))
fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1
@ -698,7 +726,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
oldFw = fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", "", "", "", ""))
fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1
@ -737,7 +765,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", ""))
cp := cert.NewCAPool()
// Packet spoofed by `c1`. Note that the remote addr is not a valid one.
@ -931,28 +959,28 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
mf := &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
// Test adding udp rule
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
// Test adding icmp rule
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
// Test adding any rule
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
// Test adding rule with cidr
cidr := netip.MustParsePrefix("10.0.0.0/8")
@ -960,14 +988,14 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr.String(), localIp: ""}, mf.lastCall)
// Test adding rule with local_cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr.String()}, mf.lastCall)
// Test adding rule with cidr ipv6
cidr6 := netip.MustParsePrefix("fd00::/8")
@ -975,49 +1003,75 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6, localIp: netip.Prefix{}}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6.String(), localIp: ""}, mf.lastCall)
// Test adding rule with any cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "any"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "any", localIp: ""}, mf.lastCall)
// Test adding rule with junk cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "junk/junk"}}}
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
// Test adding rule with local_cidr ipv6
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr6}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr6.String()}, mf.lastCall)
// Test adding rule with any local_cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "any"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, localIp: "any"}, mf.lastCall)
// Test adding rule with junk local_cidr
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "junk/junk"}}}
require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
// Test adding rule with ca_sha
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caSha: "12312313123"}, mf.lastCall)
// Test adding rule with ca_name
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caName: "root01"}, mf.lastCall)
// Test single group
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
// Test single groups
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
// Test multiple AND groups
conf = config.NewC(l)
mf = &mockFirewall{}
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: "", localIp: ""}, mf.lastCall)
// Test Add error
conf = config.NewC(l)
@ -1040,7 +1094,7 @@ func TestFirewall_convertRule(t *testing.T) {
r, err := convertRule(l, c, "test", 1)
assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
require.NoError(t, err)
assert.Equal(t, "group1", r.Group)
assert.Equal(t, []string{"group1"}, r.Groups)
// Ensure group array of > 1 is errord
ob.Reset()
@ -1060,7 +1114,63 @@ func TestFirewall_convertRule(t *testing.T) {
r, err = convertRule(l, c, "test", 1)
require.NoError(t, err)
assert.Equal(t, "group1", r.Group)
assert.Equal(t, []string{"group1"}, r.Groups)
}
func TestFirewall_convertRuleSanity(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
noWarningPlease := []map[string]any{
{"group": "group1"},
{"groups": []any{"group2"}},
{"host": "bob"},
{"cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "host": "bob"},
{"cidr": "1.1.1.1/1", "host": "bob"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
}
for _, c := range noWarningPlease {
r, err := convertRule(l, c, "test", 1)
require.NoError(t, err)
require.NoError(t, r.sanity(), "should not generate a sanity warning, %+v", c)
}
yesWarningPlease := []map[string]any{
{"group": "group1"},
{"groups": []any{"group2"}},
{"cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "host": "bob"},
{"cidr": "1.1.1.1/1", "host": "bob"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
}
for _, c := range yesWarningPlease {
c["host"] = "any"
r, err := convertRule(l, c, "test", 1)
require.NoError(t, err)
err = r.sanity()
require.Error(t, err, "I wanted a warning: %+v", c)
}
//reset the list
yesWarningPlease = []map[string]any{
{"group": "group1"},
{"groups": []any{"group2"}},
{"cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "host": "bob"},
{"cidr": "1.1.1.1/1", "host": "bob"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
{"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
}
for _, c := range yesWarningPlease {
r, err := convertRule(l, c, "test", 1)
require.NoError(t, err)
r.Groups = append(r.Groups, "any")
err = r.sanity()
require.Error(t, err, "I wanted a warning: %+v", c)
}
}
type testcase struct {
@ -1141,7 +1251,7 @@ func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
myVpnNetworksTable.Insert(prefix)
}
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
return testsetup{
c: c,
@ -1222,7 +1332,7 @@ func TestFirewall_Drop_EnforceIPMatch(t *testing.T) {
tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3")
tc.err = ErrNoMatchingRule
tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off
require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, unsafePrefix, "", ""))
require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", unsafePrefix.String(), "", ""))
tc.err = nil
tc.Test(t, unsafeSetup.fw) //should pass
})
@ -1235,8 +1345,8 @@ type addRuleCall struct {
endPort int32
groups []string
host string
ip netip.Prefix
localIp netip.Prefix
ip string
localIp string
caName string
caSha string
}
@ -1246,7 +1356,7 @@ type mockFirewall struct {
nextCallReturn error
}
func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip netip.Prefix, localIp netip.Prefix, caName string, caSha string) error {
func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp, caName string, caSha string) error {
mf.lastCall = addRuleCall{
incoming: incoming,
proto: proto,

4
go.mod
View file

@ -23,9 +23,9 @@ require (
github.com/stretchr/testify v1.11.1
github.com/vishvananda/netlink v1.3.1
go.yaml.in/yaml/v3 v3.0.4
golang.org/x/crypto v0.44.0
golang.org/x/crypto v0.45.0
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090
golang.org/x/net v0.46.0
golang.org/x/net v0.47.0
golang.org/x/sync v0.18.0
golang.org/x/sys v0.38.0
golang.org/x/term v0.37.0

8
go.sum
View file

@ -162,8 +162,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU=
golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
@ -182,8 +182,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=

View file

@ -99,11 +99,11 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
return true
}
func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) {
cs := f.pki.getCertState()
crt := cs.GetDefaultCertificate()
if crt == nil {
f.l.WithField("udpAddr", addr).
f.l.WithField("from", via).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", cs.initiatingVersion).
Error("Unable to handshake with host because no certificate is available")
@ -112,7 +112,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
if err != nil {
f.l.WithError(err).WithField("udpAddr", addr).
f.l.WithError(err).WithField("from", via).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to create connection state")
return
@ -123,7 +123,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil {
f.l.WithError(err).WithField("udpAddr", addr).
f.l.WithError(err).WithField("from", via).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to call noise.ReadMessage")
return
@ -132,7 +132,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
hs := &NebulaHandshake{}
err = hs.Unmarshal(msg)
if err != nil || hs.Details == nil {
f.l.WithError(err).WithField("udpAddr", addr).
f.l.WithError(err).WithField("from", via).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed unmarshal handshake message")
return
@ -140,7 +140,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
if err != nil {
f.l.WithError(err).WithField("udpAddr", addr).
f.l.WithError(err).WithField("from", via).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("Handshake did not contain a certificate")
return
@ -153,7 +153,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
fp = "<error generating certificate fingerprint>"
}
e := f.l.WithError(err).WithField("udpAddr", addr).
e := f.l.WithError(err).WithField("from", via).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("certVpnNetworks", rc.Networks()).
WithField("certFingerprint", fp)
@ -172,7 +172,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if myCertOtherVersion == nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithError(err).WithFields(m{
"udpAddr": addr,
"from": via,
"handshake": m{"stage": 1, "style": "ix_psk0"},
"cert": remoteCert,
}).Debug("Might be unable to handshake with host due to missing certificate version")
@ -184,7 +184,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
}
if len(remoteCert.Certificate.Networks()) == 0 {
f.l.WithError(err).WithField("udpAddr", addr).
f.l.WithError(err).WithField("from", via).
WithField("cert", remoteCert).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Info("No networks in certificate")
@ -201,7 +201,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
for i, network := range vpnNetworks {
if f.myVpnAddrsTable.Contains(network.Addr()) {
f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr).
f.l.WithField("vpnNetworks", vpnNetworks).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
@ -215,18 +215,18 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
}
}
if addr.IsValid() {
// addr can be invalid when the tunnel is being relayed.
if !via.IsRelayed {
// We only want to apply the remote allow list for direct tunnels here
if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, addr.Addr()) {
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, via.UdpAddr.Addr()) {
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
Debug("lighthouse.remote_allow_list denied incoming handshake")
return
}
}
myIndex, err := generateIndex(f.l)
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
@ -251,7 +251,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
msgRxL := f.l.WithFields(m{
"vpnAddrs": vpnAddrs,
"udpAddr": addr,
"from": via,
"certName": certName,
"certVersion": certVersion,
"fingerprint": fingerprint,
@ -283,7 +283,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
hsBytes, err := hs.Marshal()
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
@ -295,7 +295,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
@ -303,7 +303,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return
} else if dKey == nil || eKey == nil {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
@ -329,7 +329,9 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
ci.eKey = NewNebulaCipherState(eKey)
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
hostinfo.SetRemote(addr)
if !via.IsRelayed {
hostinfo.SetRemote(via.UdpAddr)
}
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
@ -337,7 +339,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
switch err {
case ErrAlreadySeen:
// Update remote if preferred
if existing.SetRemoteIfPreferred(f.hostMap, addr) {
if existing.SetRemoteIfPreferred(f.hostMap, via) {
// Send a test packet to ensure the other side has also switched to
// the preferred remote
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
@ -345,21 +347,21 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
msg = existing.HandshakePacket[2]
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
if addr.IsValid() {
err := f.outside.WriteTo(msg, addr)
if !via.IsRelayed {
err := f.outside.WriteTo(msg, via.UdpAddr)
if err != nil {
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
WithError(err).Error("Failed to send handshake message")
} else {
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
Info("Handshake message sent")
}
return
} else {
if via == nil {
f.l.Error("Handshake send failed: both addr and via are nil.")
if via.relay == nil {
f.l.Error("Handshake send failed: both addr and via.relay are nil.")
return
}
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
@ -371,7 +373,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
}
case ErrExistingHostInfo:
// This means there was an existing tunnel and this handshake was older than the one we are currently based on
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("oldHandshakeTime", existing.lastHandshakeTime).
@ -387,7 +389,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
return
case ErrLocalIndexCollision:
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
@ -400,7 +402,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
default:
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
// And we forget to update it here
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
@ -414,30 +416,23 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
// Do the send
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
if addr.IsValid() {
err = f.outside.WriteTo(msg, addr)
if !via.IsRelayed {
err = f.outside.WriteTo(msg, via.UdpAddr)
log := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
if err != nil {
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithError(err).Error("Failed to send handshake")
log.WithError(err).Error("Failed to send handshake")
} else {
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Handshake message sent")
log.Info("Handshake message sent")
}
} else {
if via == nil {
f.l.Error("Handshake send failed: both addr and via are nil.")
if via.relay == nil {
f.l.Error("Handshake send failed: both addr and via.relay are nil.")
return
}
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
@ -462,7 +457,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
return
}
func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
if hh == nil {
// Nothing here to tear down, got a bogus stage 2 packet
return true
@ -472,10 +467,10 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
defer hh.Unlock()
hostinfo := hh.hostinfo
if addr.IsValid() {
if !via.IsRelayed {
// The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list.
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, addr.Addr()) {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake")
return false
}
}
@ -483,7 +478,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
ci := hostinfo.ConnectionState
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
Error("Failed to call noise.ReadMessage")
@ -492,7 +487,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
// near future
return false
} else if dKey == nil || eKey == nil {
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Error("Noise did not arrive at a key")
@ -504,7 +499,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
hs := &NebulaHandshake{}
err = hs.Unmarshal(msg)
if err != nil || hs.Details == nil {
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
@ -513,7 +508,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
if err != nil {
f.l.WithError(err).WithField("udpAddr", addr).
f.l.WithError(err).WithField("from", via).
WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
Info("Handshake did not contain a certificate")
@ -527,7 +522,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
fp = "<error generating certificate fingerprint>"
}
e := f.l.WithError(err).WithField("udpAddr", addr).
e := f.l.WithError(err).WithField("from", via).
WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithField("certFingerprint", fp).
@ -542,7 +537,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
}
if len(remoteCert.Certificate.Networks()) == 0 {
f.l.WithError(err).WithField("udpAddr", addr).
f.l.WithError(err).WithField("from", via).
WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("cert", remoteCert).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
@ -565,8 +560,8 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
ci.eKey = NewNebulaCipherState(eKey)
// Make sure the current udpAddr being used is set for responding
if addr.IsValid() {
hostinfo.SetRemote(addr)
if !via.IsRelayed {
hostinfo.SetRemote(via.UdpAddr)
} else {
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
}
@ -588,7 +583,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
// Ensure the right host responded
if !correctHostResponded {
f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
WithField("udpAddr", addr).
WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
@ -602,7 +597,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
// Block the current used address
newHH.hostinfo.remotes = hostinfo.remotes
newHH.hostinfo.remotes.BlockRemote(addr)
newHH.hostinfo.remotes.BlockRemote(via)
f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).
WithField("vpnNetworks", vpnNetworks).
@ -625,7 +620,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
ci.window.Update(f.l, 2)
duration := time.Since(hh.startTime).Nanoseconds()
msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
WithField("certName", certName).
WithField("certVersion", certVersion).
WithField("fingerprint", fingerprint).

View file

@ -136,11 +136,11 @@ func (hm *HandshakeManager) Run(ctx context.Context) {
}
}
func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
func (hm *HandshakeManager) HandleIncoming(via ViaSender, packet []byte, h *header.H) {
// First remote allow list check before we know the vpnIp
if addr.IsValid() {
if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(addr.Addr()) {
hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
if !via.IsRelayed {
if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(via.UdpAddr.Addr()) {
hm.l.WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake")
return
}
}
@ -149,11 +149,11 @@ func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender,
case header.HandshakeIXPSK0:
switch h.MessageCounter {
case 1:
ixHandshakeStage1(hm.f, addr, via, packet, h)
ixHandshakeStage1(hm.f, via, packet, h)
case 2:
newHostinfo := hm.queryIndex(h.RemoteIndex)
tearDown := ixHandshakeStage2(hm.f, addr, via, newHostinfo, packet, h)
tearDown := ixHandshakeStage2(hm.f, via, newHostinfo, packet, h)
if tearDown && newHostinfo != nil {
hm.DeleteHostInfo(newHostinfo.hostinfo)
}

View file

@ -1,7 +1,9 @@
package nebula
import (
"encoding/json"
"errors"
"fmt"
"net"
"net/netip"
"slices"
@ -276,9 +278,25 @@ type HostInfo struct {
}
type ViaSender struct {
UdpAddr netip.AddrPort
relayHI *HostInfo // relayHI is the host info object of the relay
remoteIdx uint32 // remoteIdx is the index included in the header of the received packet
relay *Relay // relay contains the rest of the relay information, including the PeerIP of the host trying to communicate with us.
IsRelayed bool // IsRelayed is true if the packet was sent through a relay
}
func (v ViaSender) String() string {
if v.IsRelayed {
return fmt.Sprintf("%s (relayed)", v.UdpAddr)
}
return v.UdpAddr.String()
}
func (v ViaSender) MarshalJSON() ([]byte, error) {
if v.IsRelayed {
return json.Marshal(m{"relay": v.UdpAddr})
}
return json.Marshal(m{"direct": v.UdpAddr})
}
type cachedPacket struct {
@ -694,6 +712,7 @@ func (i *HostInfo) GetCert() *cert.CachedCertificate {
return nil
}
// TODO: Maybe use ViaSender here?
func (i *HostInfo) SetRemote(remote netip.AddrPort) {
// We copy here because we likely got this remote from a source that reuses the object
if i.remote != remote {
@ -704,14 +723,14 @@ func (i *HostInfo) SetRemote(remote netip.AddrPort) {
// SetRemoteIfPreferred returns true if the remote was changed. The lastRoam
// time on the HostInfo will also be updated.
func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) bool {
if !newRemote.IsValid() {
// relays have nil udp Addrs
func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, via ViaSender) bool {
if via.IsRelayed {
return false
}
currentRemote := i.remote
if !currentRemote.IsValid() {
i.SetRemote(newRemote)
i.SetRemote(via.UdpAddr)
return true
}
@ -724,7 +743,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
return false
}
if l.Contains(newRemote.Addr()) {
if l.Contains(via.UdpAddr.Addr()) {
newIsPreferred = true
}
}
@ -734,7 +753,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
i.lastRoam = time.Now()
i.lastRoamRemote = currentRemote
i.SetRemote(newRemote)
i.SetRemote(via.UdpAddr)
return true
}

View file

@ -222,6 +222,13 @@ func (f *Interface) activate() {
WithField("boringcrypto", boringEnabled()).
Info("Nebula interface is active")
if f.routines > 1 {
if !f.inside.SupportsMultiqueue() || !f.outside.SupportsMultipleReaders() {
f.routines = 1
f.l.Warn("routines is not supported on this platform, falling back to a single routine")
}
}
metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines))
// Prepare n tun queues
@ -272,7 +279,7 @@ func (f *Interface) listenOut(i int) {
nb := make([]byte, 12, 12)
li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
})
}

View file

@ -19,21 +19,21 @@ const (
minFwPacketLen = 4
)
func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
err := h.Parse(packet)
if err != nil {
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
if len(packet) > 1 {
f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip, err)
f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", via, err)
}
return
}
//l.Error("in packet ", header, packet[HeaderLen:])
if ip.IsValid() {
if f.myVpnNetworksTable.Contains(ip.Addr()) {
if !via.IsRelayed {
if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
f.l.WithField("from", via).Debug("Refusing to process double encrypted packet")
}
return
}
@ -54,8 +54,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
switch h.Type {
case header.Message:
// TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case.
if !f.handleEncrypted(ci, ip, h) {
if !f.handleEncrypted(ci, via, h) {
return
}
@ -79,7 +78,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
// Successfully validated the thing. Get rid of the Relay header.
signedPayload = signedPayload[header.Len:]
// Pull the Roaming parts up here, and return in all call paths.
f.handleHostRoaming(hostinfo, ip)
f.handleHostRoaming(hostinfo, via)
// Track usage of both the HostInfo and the Relay for the received & authenticated packet
f.connectionManager.In(hostinfo)
f.connectionManager.RelayUsed(h.RemoteIndex)
@ -96,7 +95,14 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
case TerminalType:
// If I am the target of this relay, process the unwrapped packet
// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
via = ViaSender{
UdpAddr: via.UdpAddr,
relayHI: hostinfo,
remoteIdx: relay.RemoteIndex,
relay: relay,
IsRelayed: true,
}
f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
return
case ForwardingType:
// Find the target HostInfo relay object
@ -126,31 +132,32 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
case header.LightHouse:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, ip, h) {
if !f.handleEncrypted(ci, via, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
hostinfo.logger(f.l).WithError(err).WithField("from", via).
WithField("packet", packet).
Error("Failed to decrypt lighthouse packet")
return
}
lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f)
//TODO: assert via is not relayed
lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, d, f)
// Fallthrough to the bottom to record incoming traffic
case header.Test:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, ip, h) {
if !f.handleEncrypted(ci, via, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
hostinfo.logger(f.l).WithError(err).WithField("from", via).
WithField("packet", packet).
Error("Failed to decrypt test packet")
return
@ -159,7 +166,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
if h.Subtype == header.TestRequest {
// This testRequest might be from TryPromoteBest, so we should roam
// to the new IP address before responding
f.handleHostRoaming(hostinfo, ip)
f.handleHostRoaming(hostinfo, via)
f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out)
}
@ -170,34 +177,34 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
case header.Handshake:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
f.handshakeManager.HandleIncoming(ip, via, packet, h)
f.handshakeManager.HandleIncoming(via, packet, h)
return
case header.RecvError:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
f.handleRecvError(ip, h)
f.handleRecvError(via.UdpAddr, h)
return
case header.CloseTunnel:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
if !f.handleEncrypted(ci, ip, h) {
if !f.handleEncrypted(ci, via, h) {
return
}
hostinfo.logger(f.l).WithField("udpAddr", ip).
hostinfo.logger(f.l).WithField("from", via).
Info("Close tunnel received, tearing down.")
f.closeTunnel(hostinfo)
return
case header.Control:
if !f.handleEncrypted(ci, ip, h) {
if !f.handleEncrypted(ci, via, h) {
return
}
d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
if err != nil {
hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
hostinfo.logger(f.l).WithError(err).WithField("from", via).
WithField("packet", packet).
Error("Failed to decrypt Control packet")
return
@ -207,11 +214,11 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
default:
f.messageMetrics.Rx(h.Type, h.Subtype, 1)
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip)
hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", via)
return
}
f.handleHostRoaming(hostinfo, ip)
f.handleHostRoaming(hostinfo, via)
f.connectionManager.In(hostinfo)
}
@ -230,36 +237,36 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) {
f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
}
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort) {
if udpAddr.IsValid() && hostinfo.remote != udpAddr {
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, udpAddr.Addr()) {
hostinfo.logger(f.l).WithField("newAddr", udpAddr).Debug("lighthouse.remote_allow_list denied roaming")
func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) {
if !via.IsRelayed && hostinfo.remote != via.UdpAddr {
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
hostinfo.logger(f.l).WithField("newAddr", via.UdpAddr).Debug("lighthouse.remote_allow_list denied roaming")
return
}
if !hostinfo.lastRoam.IsZero() && udpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
if !hostinfo.lastRoam.IsZero() && via.UdpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr).
Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
}
return
}
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr).
Info("Host roamed to new udp ip/port.")
hostinfo.lastRoam = time.Now()
hostinfo.lastRoamRemote = hostinfo.remote
hostinfo.SetRemote(udpAddr)
hostinfo.SetRemote(via.UdpAddr)
}
}
// handleEncrypted returns true if a packet should be processed, false otherwise
func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool {
func (f *Interface) handleEncrypted(ci *ConnectionState, via ViaSender, h *header.H) bool {
// If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect
if ci == nil {
if addr.IsValid() {
f.maybeSendRecvError(addr, h.RemoteIndex)
if !via.IsRelayed {
f.maybeSendRecvError(via.UdpAddr, h.RemoteIndex)
}
return false
}

View file

@ -13,5 +13,6 @@ type Device interface {
Networks() []netip.Prefix
Name() string
RoutesFor(netip.Addr) routing.Gateways
SupportsMultiqueue() bool
NewMultiQueueReader() (io.ReadWriteCloser, error)
}

View file

@ -95,6 +95,10 @@ func (t *tun) Name() string {
return "android"
}
func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for android")
}

View file

@ -549,6 +549,10 @@ func (t *tun) Name() string {
return t.Device
}
func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
}

View file

@ -105,6 +105,10 @@ func (t *disabledTun) Write(b []byte) (int, error) {
return len(b), nil
}
func (t *disabledTun) SupportsMultiqueue() bool {
return true
}
func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return t, nil
}

View file

@ -450,6 +450,10 @@ func (t *tun) Name() string {
return t.Device
}
func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
}

View file

@ -151,6 +151,10 @@ func (t *tun) Name() string {
return "iOS"
}
func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for ios")
}

View file

@ -216,6 +216,10 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil
}
func (t *tun) SupportsMultiqueue() bool {
return true
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
@ -582,48 +586,42 @@ func (t *tun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool {
}
func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
var gateways routing.Gateways
link, err := netlink.LinkByName(t.Device)
if err != nil {
t.l.WithField("Devicename", t.Device).Error("Ignoring route update: failed to get link by name")
t.l.WithField("deviceName", t.Device).Error("Ignoring route update: failed to get link by name")
return gateways
}
// If this route is relevant to our interface and there is a gateway then add it
if r.LinkIndex == link.Attrs().Index && len(r.Gw) > 0 {
gwAddr, ok := netip.AddrFromSlice(r.Gw)
if !ok {
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
} else {
gwAddr = gwAddr.Unmap()
if !t.isGatewayInVpnNetworks(gwAddr) {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
} else {
if r.LinkIndex == link.Attrs().Index {
gwAddr, ok := getGatewayAddr(r.Gw, r.Via)
if ok {
if t.isGatewayInVpnNetworks(gwAddr) {
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
} else {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network")
}
} else {
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address")
}
}
for _, p := range r.MultiPath {
// If this route is relevant to our interface and there is a gateway then add it
if p.LinkIndex == link.Attrs().Index && len(p.Gw) > 0 {
gwAddr, ok := netip.AddrFromSlice(p.Gw)
if !ok {
t.l.WithField("route", r).Debug("Ignoring multipath route update, invalid gateway address")
} else {
gwAddr = gwAddr.Unmap()
if !t.isGatewayInVpnNetworks(gwAddr) {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
} else {
// p.Hops+1 = weight of the route
if p.LinkIndex == link.Attrs().Index {
gwAddr, ok := getGatewayAddr(p.Gw, p.Via)
if ok {
if t.isGatewayInVpnNetworks(gwAddr) {
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
} else {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network")
}
} else {
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address")
}
}
}
@ -632,10 +630,27 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
return gateways
}
func getGatewayAddr(gw net.IP, via netlink.Destination) (netip.Addr, bool) {
// Try to use the old RTA_GATEWAY first
gwAddr, ok := netip.AddrFromSlice(gw)
if !ok {
// Fallback to the new RTA_VIA
rVia, ok := via.(*netlink.Via)
if ok {
gwAddr, ok = netip.AddrFromSlice(rVia.Addr)
}
}
if gwAddr.IsValid() {
gwAddr = gwAddr.Unmap()
return gwAddr, true
}
return netip.Addr{}, false
}
func (t *tun) updateRoutes(r netlink.RouteUpdate) {
gateways := t.getGatewaysFromRoute(&r.Route)
if len(gateways) == 0 {
// No gateways relevant to our network, no routing changes required.
t.l.WithField("route", r).Debug("Ignoring route update, no gateways")

View file

@ -390,6 +390,10 @@ func (t *tun) Name() string {
return t.Device
}
func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd")
}

View file

@ -310,6 +310,10 @@ func (t *tun) Name() string {
return t.Device
}
func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
}

View file

@ -132,6 +132,10 @@ func (t *TestTun) Read(b []byte) (int, error) {
return len(p), nil
}
func (t *TestTun) SupportsMultiqueue() bool {
return false
}
func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented")
}

View file

@ -234,6 +234,10 @@ func (t *winTun) Write(b []byte) (int, error) {
return t.tun.Write(b, 0)
}
func (t *winTun) SupportsMultiqueue() bool {
return false
}
func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
}

View file

@ -46,6 +46,10 @@ func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways {
return routing.Gateways{routing.NewGateway(ip, 1)}
}
func (d *UserDevice) SupportsMultiqueue() bool {
return true
}
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return d, nil
}

View file

@ -338,21 +338,21 @@ func (r *RemoteList) CopyCache() *CacheMap {
}
// BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list
func (r *RemoteList) BlockRemote(bad netip.AddrPort) {
if !bad.IsValid() {
// relays can have nil udp Addrs
func (r *RemoteList) BlockRemote(bad ViaSender) {
if bad.IsRelayed {
return
}
r.Lock()
defer r.Unlock()
// Check if we already blocked this addr
if r.unlockedIsBad(bad) {
if r.unlockedIsBad(bad.UdpAddr) {
return
}
// We copy here because we are taking something else's memory and we can't trust everything
r.badRemotes = append(r.badRemotes, bad)
r.badRemotes = append(r.badRemotes, bad.UdpAddr)
// Mark the next interaction must recollect/dedupe
r.shouldRebuild = true

View file

@ -34,6 +34,10 @@ func (NoopTun) Write([]byte) (int, error) {
return 0, nil
}
func (NoopTun) SupportsMultiqueue() bool {
return false
}
func (NoopTun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return nil, errors.New("unsupported")
}

View file

@ -19,6 +19,7 @@ type Conn interface {
ListenOut(r EncReader)
WriteTo(b []byte, addr netip.AddrPort) error
ReloadConfig(c *config.C)
SupportsMultipleReaders() bool
Close() error
}
@ -33,6 +34,9 @@ func (NoopConn) LocalAddr() (netip.AddrPort, error) {
func (NoopConn) ListenOut(_ EncReader) {
return
}
func (NoopConn) SupportsMultipleReaders() bool {
return false
}
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
return nil
}

View file

@ -98,9 +98,9 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
return ErrInvalidIPv6RemoteForSocket
}
var rsa unix.RawSockaddrInet6
rsa.Family = unix.AF_INET6
rsa.Addr = ap.Addr().As16()
var rsa unix.RawSockaddrInet4
rsa.Family = unix.AF_INET
rsa.Addr = ap.Addr().As4()
binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port())
sa = unsafe.Pointer(&rsa)
addrLen = syscall.SizeofSockaddrInet4
@ -184,6 +184,10 @@ func (u *StdConn) ListenOut(r EncReader) {
}
}
func (u *StdConn) SupportsMultipleReaders() bool {
return false
}
func (u *StdConn) Rebind() error {
var err error
if u.isV4 {

View file

@ -85,3 +85,7 @@ func (u *GenericConn) ListenOut(r EncReader) {
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
}
}
func (u *GenericConn) SupportsMultipleReaders() bool {
return false
}

View file

@ -72,6 +72,10 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in
return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err
}
func (u *StdConn) SupportsMultipleReaders() bool {
return true
}
func (u *StdConn) Rebind() error {
return nil
}

View file

@ -315,6 +315,10 @@ func (u *RIOConn) LocalAddr() (netip.AddrPort, error) {
}
func (u *RIOConn) SupportsMultipleReaders() bool {
return false
}
func (u *RIOConn) Rebind() error {
return nil
}

View file

@ -127,6 +127,10 @@ func (u *TesterConn) LocalAddr() (netip.AddrPort, error) {
return u.Addr, nil
}
func (u *TesterConn) SupportsMultipleReaders() bool {
return false
}
func (u *TesterConn) Rebind() error {
return nil
}