diff --git a/.github/workflows/gofmt.yml b/.github/workflows/gofmt.yml index 288f32ce..51575623 100644 --- a/.github/workflows/gofmt.yml +++ b/.github/workflows/gofmt.yml @@ -14,11 +14,11 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - - uses: actions/setup-go@v5 + - uses: actions/setup-go@v6 with: - go-version: '1.24' + go-version: '1.25' check-latest: true - name: Install goimports diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 3107b474..35d72dea 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -10,11 +10,11 @@ jobs: name: Build Linux/BSD All runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - - uses: actions/setup-go@v5 + - uses: actions/setup-go@v6 with: - go-version: '1.24' + go-version: '1.25' check-latest: true - name: Build @@ -24,7 +24,7 @@ jobs: mv build/*.tar.gz release - name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v5 with: name: linux-latest path: release @@ -33,11 +33,11 @@ jobs: name: Build Windows runs-on: windows-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - - uses: actions/setup-go@v5 + - uses: actions/setup-go@v6 with: - go-version: '1.24' + go-version: '1.25' check-latest: true - name: Build @@ -55,7 +55,7 @@ jobs: mv dist\windows\wintun build\dist\windows\ - name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v5 with: name: windows-latest path: build @@ -66,11 +66,11 @@ jobs: HAS_SIGNING_CREDS: ${{ secrets.AC_USERNAME != '' }} runs-on: macos-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - - uses: actions/setup-go@v5 + - uses: actions/setup-go@v6 with: - go-version: '1.24' + go-version: '1.25' check-latest: true - name: Import certificates @@ -104,7 +104,7 @@ jobs: fi - name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: 
actions/upload-artifact@v5 with: name: darwin-latest path: ./release/* @@ -124,11 +124,11 @@ jobs: # be overwritten - name: Checkout code if: ${{ env.HAS_DOCKER_CREDS == 'true' }} - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Download artifacts if: ${{ env.HAS_DOCKER_CREDS == 'true' }} - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v6 with: name: linux-latest path: artifacts @@ -160,10 +160,10 @@ jobs: needs: [build-linux, build-darwin, build-windows] runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Download artifacts - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v6 with: path: artifacts diff --git a/.github/workflows/smoke-extra.yml b/.github/workflows/smoke-extra.yml index de582de9..966bddd2 100644 --- a/.github/workflows/smoke-extra.yml +++ b/.github/workflows/smoke-extra.yml @@ -20,11 +20,11 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - - uses: actions/setup-go@v5 + - uses: actions/setup-go@v6 with: - go-version-file: 'go.mod' + go-version: '1.25' check-latest: true - name: add hashicorp source diff --git a/.github/workflows/smoke.yml b/.github/workflows/smoke.yml index c4eac12f..0ccae47b 100644 --- a/.github/workflows/smoke.yml +++ b/.github/workflows/smoke.yml @@ -18,11 +18,11 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - - uses: actions/setup-go@v5 + - uses: actions/setup-go@v6 with: - go-version: '1.24' + go-version: '1.25' check-latest: true - name: build diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 00b3936b..050a68e4 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -18,11 +18,11 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - - uses: actions/setup-go@v5 + - uses: actions/setup-go@v6 with: - go-version: '1.24' + go-version: 
'1.25' check-latest: true - name: Build @@ -32,9 +32,9 @@ jobs: run: make vet - name: golangci-lint - uses: golangci/golangci-lint-action@v8 + uses: golangci/golangci-lint-action@v9 with: - version: v2.1 + version: v2.5 - name: Test run: make test @@ -45,7 +45,7 @@ jobs: - name: Build test mobile run: make build-test-mobile - - uses: actions/upload-artifact@v4 + - uses: actions/upload-artifact@v5 with: name: e2e packet flow linux-latest path: e2e/mermaid/linux-latest @@ -56,11 +56,11 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - - uses: actions/setup-go@v5 + - uses: actions/setup-go@v6 with: - go-version: '1.24' + go-version: '1.25' check-latest: true - name: Build @@ -77,11 +77,11 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - - uses: actions/setup-go@v5 + - uses: actions/setup-go@v6 with: - go-version: '1.22' + go-version: '1.25' check-latest: true - name: Build @@ -98,11 +98,11 @@ jobs: os: [windows-latest, macos-latest] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - - uses: actions/setup-go@v5 + - uses: actions/setup-go@v6 with: - go-version: '1.24' + go-version: '1.25' check-latest: true - name: Build nebula @@ -115,9 +115,9 @@ jobs: run: make vet - name: golangci-lint - uses: golangci/golangci-lint-action@v8 + uses: golangci/golangci-lint-action@v9 with: - version: v2.1 + version: v2.5 - name: Test run: make test @@ -125,7 +125,7 @@ jobs: - name: End 2 end run: make e2evv - - uses: actions/upload-artifact@v4 + - uses: actions/upload-artifact@v5 with: name: e2e packet flow ${{ matrix.os }} path: e2e/mermaid/${{ matrix.os }} diff --git a/CHANGELOG.md b/CHANGELOG.md index 1de3c196..0efa7959 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,12 +7,85 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [1.10.0] - 2025-12-04 + +See the 
[v1.10.0](https://github.com/slackhq/nebula/milestone/16?closed=1) milestone for a complete list of changes. + +### Added + +- Support for ipv6 and multiple ipv4/6 addresses in the overlay. + A new v2 ASN.1 based certificate format. + Certificates now have a unified interface for external implementations. + (#1212, #1216, #1345, #1359, #1381, #1419, #1464, #1466, #1451, #1476, #1467, #1481, #1399, #1488, #1492, #1495, #1468, #1521, #1535, #1538) +- Add the ability to mark packets on linux to better target nebula packets in iptables/nftables. (#1331) +- Add ECMP support for `unsafe_routes`. (#1332) +- PKCS11 support for P256 keys when built with `pkcs11` tag (#1153, #1482) + ### Changed -- `default_local_cidr_any` now defaults to false, meaning that any firewall rule +- **NOTE**: `default_local_cidr_any` now defaults to false, meaning that any firewall rule intended to target an `unsafe_routes` entry must explicitly declare it via the `local_cidr` field. This is almost always the intended behavior. This flag is - deprecated and will be removed in a future release. + deprecated and will be removed in a future release. (#1373) +- Improve logging when a relay is in use on an inbound packet. (#1533) +- Avoid fatal errors if `routines` is > 1 on systems that don't support more than 1 routine. (#1531) +- Log a warning if a firewall rule contains an `any` that negates a more restrictive filter. (#1513) +- Accept encrypted CA passphrase from an environment variable. (#1421) +- Allow handshaking with any trusted remote. (#1509) +- Log only the count of blocklisted certificate fingerprints instead of the entire list. (#1525) +- Don't fatal when the ssh server is unable to be configured successfully. (#1520) +- Update to build against go v1.25. (#1483) +- Allow projects using `nebula` as a library with userspace networking to configure the `logger` and build version. (#1239) +- Upgrade to `yaml.v3`.
(#1148, #1371, #1438, #1478) + +### Fixed + +- Fix a potential bug with udp ipv4 only on darwin. (#1532) +- Improve lost packet statistics. (#1441, #1537) +- Honor `remote_allow_list` in hole punch response. (#1186) +- Fix a panic when `tun.use_system_route_table` is `true` and a route lacks a destination. (#1437) +- Fix an issue when `tun.use_system_route_table: true` could result in heavy CPU utilization when many thousands of routes + are present. (#1326) +- Fix tests for 32 bit machines. (#1394) +- Fix a possible 32bit integer underflow in config handling. (#1353) +- Fix moving a udp address from one vpn address to another in the `static_host_map` + which could cause rapid re-handshaking with an incorrect remote. (#1259) +- Improve smoke tests in environments where the docker network is not the default. (#1347) + +## [1.9.7] - 2025-10-10 + +### Security + +- Fix an issue where Nebula could incorrectly accept and process a packet from an erroneous source IP when the sender's + certificate is configured with unsafe_routes (cert v1/v2) or multiple IPs (cert v2). (#1494) + +### Changed + +- Disable sending `recv_error` messages when a packet is received outside the allowable counter window. (#1459) +- Improve error messages and remove some unnecessary fatal conditions in the Windows and generic udp listener. (#1453) + +## [1.9.6] - 2025-07-15 + +### Added + +- Support dropping inactive tunnels. This is disabled by default in this release but can be enabled with `tunnels.drop_inactive`. See example config for more details. (#1413) + +### Fixed + +- Fix Darwin freeze due to presence of some Network Extensions (#1426) +- Ensure the same relay tunnel is always used when multiple relay tunnels are present (#1422) +- Fix Windows freeze due to ICMP error handling (#1412) +- Fix relay migration panic (#1403) + +## [1.9.5] - 2024-12-05 + +### Added + +- Gracefully ignore v2 certificates.
(#1282) + +### Fixed + +- Fix relays that refuse to re-establish after one of the remote tunnel pairs breaks. (#1277) ## [1.9.4] - 2024-09-09 @@ -671,7 +744,11 @@ created.) - Initial public release. -[Unreleased]: https://github.com/slackhq/nebula/compare/v1.9.4...HEAD +[Unreleased]: https://github.com/slackhq/nebula/compare/v1.10.0...HEAD +[1.10.0]: https://github.com/slackhq/nebula/releases/tag/v1.10.0 +[1.9.7]: https://github.com/slackhq/nebula/releases/tag/v1.9.7 +[1.9.6]: https://github.com/slackhq/nebula/releases/tag/v1.9.6 +[1.9.5]: https://github.com/slackhq/nebula/releases/tag/v1.9.5 [1.9.4]: https://github.com/slackhq/nebula/releases/tag/v1.9.4 [1.9.3]: https://github.com/slackhq/nebula/releases/tag/v1.9.3 [1.9.2]: https://github.com/slackhq/nebula/releases/tag/v1.9.2 diff --git a/bits.go b/bits.go index b4f96c6c..af11cc48 100644 --- a/bits.go +++ b/bits.go @@ -9,14 +9,13 @@ type Bits struct { length uint64 current uint64 bits []bool - firstSeen bool lostCounter metrics.Counter dupeCounter metrics.Counter outOfWindowCounter metrics.Counter } func NewBits(bits uint64) *Bits { - return &Bits{ + b := &Bits{ length: bits, bits: make([]bool, bits, bits), current: 0, @@ -24,34 +23,37 @@ func NewBits(bits uint64) *Bits { dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil), outOfWindowCounter: metrics.GetOrRegisterCounter("network.packets.out_of_window", nil), } + + // There is no counter value 0, mark it to avoid counting a lost packet later. + b.bits[0] = true + b.current = 0 + return b } -func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool { +func (b *Bits) Check(l *logrus.Logger, i uint64) bool { // If i is the next number, return true. - if i > b.current || (i == 0 && b.firstSeen == false && b.current < b.length) { + if i > b.current { return true } - // If i is within the window, check if it's been set already. 
The first window will fail this check - if i > b.current-b.length { - return !b.bits[i%b.length] - } - - // If i is within the first window - if i < b.length { + // If i is within the window, check if it's been set already. + if i > b.current-b.length || i < b.length && b.current < b.length { return !b.bits[i%b.length] } // Not within the window - l.Debugf("rejected a packet (top) %d %d\n", b.current, i) + if l.Level >= logrus.DebugLevel { + l.Debugf("rejected a packet (top) %d %d\n", b.current, i) + } return false } func (b *Bits) Update(l *logrus.Logger, i uint64) bool { // If i is the next number, return true and update current. if i == b.current+1 { - // Report missed packets, we can only understand what was missed after the first window has been gone through - if i > b.length && b.bits[i%b.length] == false { + // Check if the oldest bit was lost since we are shifting the window by 1 and occupying it with this counter + // The very first window can only be tracked as lost once we are on the 2nd window or greater + if b.bits[i%b.length] == false && i > b.length { b.lostCounter.Inc(1) } b.bits[i%b.length] = true @@ -59,61 +61,32 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool { return true } - // If i packet is greater than current but less than the maximum length of our bitmap, - // flip everything in between to false and move ahead. 
- if i > b.current && i < b.current+b.length { - // In between current and i need to be zero'd to allow those packets to come in later - for n := b.current + 1; n < i; n++ { + // If i is a jump, adjust the window, record lost, update current, and return true + if i > b.current { + lost := int64(0) + // Zero out the bits between the current and the new counter value, limited by the window size, + // since the window is shifting + for n := b.current + 1; n <= min(i, b.current+b.length); n++ { + if b.bits[n%b.length] == false && n > b.length { + lost++ + } b.bits[n%b.length] = false } - b.bits[i%b.length] = true - b.current = i - //l.Debugf("missed %d packets between %d and %d\n", i-b.current, i, b.current) - return true - } - - // If i is greater than the delta between current and the total length of our bitmap, - // just flip everything in the map and move ahead. - if i >= b.current+b.length { - // The current window loss will be accounted for later, only record the jump as loss up until then - lost := maxInt64(0, int64(i-b.current-b.length)) - //TODO: explain this - if b.current == 0 { - lost++ - } - - for n := range b.bits { - // Don't want to count the first window as a loss - //TODO: this is likely wrong, we are wanting to track only the bit slots that we aren't going to track anymore and this is marking everything as missed - //if b.bits[n] == false { - // lost++ - //} - b.bits[n] = false - } - + // Only record any skipped packets as a result of the window moving further than the window length + // Any loss within the new window will be accounted for in future calls + lost += max(0, int64(i-b.current-b.length)) b.lostCounter.Inc(lost) - if l.Level >= logrus.DebugLevel { - l.WithField("receiveWindow", m{"accepted": true, "currentCounter": b.current, "incomingCounter": i, "reason": "window shifting"}). 
- Debug("Receive window") - } b.bits[i%b.length] = true b.current = i return true } - // Allow for the 0 packet to come in within the first window - if i == 0 && b.firstSeen == false && b.current < b.length { - b.firstSeen = true - b.bits[i%b.length] = true - return true - } - - // If i is within the window of current minus length (the total pat window size), - // allow it and flip to true but to NOT change current. We also have to account for the first window - if ((b.current >= b.length && i > b.current-b.length) || (b.current < b.length && i < b.length)) && i <= b.current { - if b.current == i { + // If i is within the current window but below the current counter, + // Check to see if it's a duplicate + if i > b.current-b.length || i < b.length && b.current < b.length { + if b.current == i || b.bits[i%b.length] == true { if l.Level >= logrus.DebugLevel { l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "duplicate"}). Debug("Receive window") @@ -122,18 +95,8 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool { return false } - if b.bits[i%b.length] == true { - if l.Level >= logrus.DebugLevel { - l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "old duplicate"}). - Debug("Receive window") - } - b.dupeCounter.Inc(1) - return false - } - b.bits[i%b.length] = true return true - } // In all other cases, fail and don't change current. @@ -147,11 +110,3 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool { } return false } - -func maxInt64(a, b int64) int64 { - if a > b { - return a - } - - return b -} diff --git a/bits_test.go b/bits_test.go index 95abe018..3504cefa 100644 --- a/bits_test.go +++ b/bits_test.go @@ -15,48 +15,41 @@ func TestBits(t *testing.T) { assert.Len(t, b.bits, 10) // This is initialized to zero - receive one. This should work. 
- assert.True(t, b.Check(l, 1)) - u := b.Update(l, 1) - assert.True(t, u) + assert.True(t, b.Update(l, 1)) assert.EqualValues(t, 1, b.current) - g := []bool{false, true, false, false, false, false, false, false, false, false} + g := []bool{true, true, false, false, false, false, false, false, false, false} assert.Equal(t, g, b.bits) // Receive two assert.True(t, b.Check(l, 2)) - u = b.Update(l, 2) - assert.True(t, u) + assert.True(t, b.Update(l, 2)) assert.EqualValues(t, 2, b.current) - g = []bool{false, true, true, false, false, false, false, false, false, false} + g = []bool{true, true, true, false, false, false, false, false, false, false} assert.Equal(t, g, b.bits) // Receive two again - it will fail assert.False(t, b.Check(l, 2)) - u = b.Update(l, 2) - assert.False(t, u) + assert.False(t, b.Update(l, 2)) assert.EqualValues(t, 2, b.current) // Jump ahead to 15, which should clear everything and set the 6th element assert.True(t, b.Check(l, 15)) - u = b.Update(l, 15) - assert.True(t, u) + assert.True(t, b.Update(l, 15)) assert.EqualValues(t, 15, b.current) g = []bool{false, false, false, false, false, true, false, false, false, false} assert.Equal(t, g, b.bits) // Mark 14, which is allowed because it is in the window assert.True(t, b.Check(l, 14)) - u = b.Update(l, 14) - assert.True(t, u) + assert.True(t, b.Update(l, 14)) assert.EqualValues(t, 15, b.current) g = []bool{false, false, false, false, true, true, false, false, false, false} assert.Equal(t, g, b.bits) // Mark 5, which is not allowed because it is not in the window assert.False(t, b.Check(l, 5)) - u = b.Update(l, 5) - assert.False(t, u) + assert.False(t, b.Update(l, 5)) assert.EqualValues(t, 15, b.current) g = []bool{false, false, false, false, true, true, false, false, false, false} assert.Equal(t, g, b.bits) @@ -69,10 +62,29 @@ func TestBits(t *testing.T) { // Walk through a few windows in order b = NewBits(10) - for i := uint64(0); i <= 100; i++ { + for i := uint64(1); i <= 100; i++ { assert.True(t, 
b.Check(l, i), "Error while checking %v", i) assert.True(t, b.Update(l, i), "Error while updating %v", i) } + + assert.False(t, b.Check(l, 1), "Out of window check") +} + +func TestBitsLargeJumps(t *testing.T) { + l := test.NewLogger() + b := NewBits(10) + b.lostCounter.Clear() + + b = NewBits(10) + b.lostCounter.Clear() + assert.True(t, b.Update(l, 55)) // We saw packet 55 and can still track 45,46,47,48,49,50,51,52,53,54 + assert.Equal(t, int64(45), b.lostCounter.Count()) + + assert.True(t, b.Update(l, 100)) // We saw packet 55 and 100 and can still track 90,91,92,93,94,95,96,97,98,99 + assert.Equal(t, int64(89), b.lostCounter.Count()) + + assert.True(t, b.Update(l, 200)) // We saw packet 55, 100, and 200 and can still track 190,191,192,193,194,195,196,197,198,199 + assert.Equal(t, int64(188), b.lostCounter.Count()) } func TestBitsDupeCounter(t *testing.T) { @@ -124,8 +136,7 @@ func TestBitsOutOfWindowCounter(t *testing.T) { assert.False(t, b.Update(l, 0)) assert.Equal(t, int64(1), b.outOfWindowCounter.Count()) - //tODO: make sure lostcounter doesn't increase in orderly increment - assert.Equal(t, int64(20), b.lostCounter.Count()) + assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost assert.Equal(t, int64(0), b.dupeCounter.Count()) assert.Equal(t, int64(1), b.outOfWindowCounter.Count()) } @@ -137,8 +148,6 @@ func TestBitsLostCounter(t *testing.T) { b.dupeCounter.Clear() b.outOfWindowCounter.Clear() - //assert.True(t, b.Update(0)) - assert.True(t, b.Update(l, 0)) assert.True(t, b.Update(l, 20)) assert.True(t, b.Update(l, 21)) assert.True(t, b.Update(l, 22)) @@ -149,7 +158,7 @@ func TestBitsLostCounter(t *testing.T) { assert.True(t, b.Update(l, 27)) assert.True(t, b.Update(l, 28)) assert.True(t, b.Update(l, 29)) - assert.Equal(t, int64(20), b.lostCounter.Count()) + assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost assert.Equal(t, int64(0), b.dupeCounter.Count()) assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) @@ 
-158,8 +167,6 @@ func TestBitsLostCounter(t *testing.T) { b.dupeCounter.Clear() b.outOfWindowCounter.Clear() - assert.True(t, b.Update(l, 0)) - assert.Equal(t, int64(0), b.lostCounter.Count()) assert.True(t, b.Update(l, 9)) assert.Equal(t, int64(0), b.lostCounter.Count()) // 10 will set 0 index, 0 was already set, no lost packets @@ -214,6 +221,62 @@ func TestBitsLostCounter(t *testing.T) { assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) } +func TestBitsLostCounterIssue1(t *testing.T) { + l := test.NewLogger() + b := NewBits(10) + b.lostCounter.Clear() + b.dupeCounter.Clear() + b.outOfWindowCounter.Clear() + + assert.True(t, b.Update(l, 4)) + assert.Equal(t, int64(0), b.lostCounter.Count()) + assert.True(t, b.Update(l, 1)) + assert.Equal(t, int64(0), b.lostCounter.Count()) + assert.True(t, b.Update(l, 9)) + assert.Equal(t, int64(0), b.lostCounter.Count()) + assert.True(t, b.Update(l, 2)) + assert.Equal(t, int64(0), b.lostCounter.Count()) + assert.True(t, b.Update(l, 3)) + assert.Equal(t, int64(0), b.lostCounter.Count()) + assert.True(t, b.Update(l, 5)) + assert.Equal(t, int64(0), b.lostCounter.Count()) + assert.True(t, b.Update(l, 6)) + assert.Equal(t, int64(0), b.lostCounter.Count()) + assert.True(t, b.Update(l, 7)) + assert.Equal(t, int64(0), b.lostCounter.Count()) + // assert.True(t, b.Update(l, 8)) + assert.True(t, b.Update(l, 10)) + assert.Equal(t, int64(0), b.lostCounter.Count()) + assert.True(t, b.Update(l, 11)) + assert.Equal(t, int64(0), b.lostCounter.Count()) + + assert.True(t, b.Update(l, 14)) + assert.Equal(t, int64(0), b.lostCounter.Count()) + // Issue seems to be here, we reset missing packet 8 to false here and don't increment the lost counter + assert.True(t, b.Update(l, 19)) + assert.Equal(t, int64(1), b.lostCounter.Count()) + assert.True(t, b.Update(l, 12)) + assert.Equal(t, int64(1), b.lostCounter.Count()) + assert.True(t, b.Update(l, 13)) + assert.Equal(t, int64(1), b.lostCounter.Count()) + assert.True(t, b.Update(l, 15)) + 
assert.Equal(t, int64(1), b.lostCounter.Count()) + assert.True(t, b.Update(l, 16)) + assert.Equal(t, int64(1), b.lostCounter.Count()) + assert.True(t, b.Update(l, 17)) + assert.Equal(t, int64(1), b.lostCounter.Count()) + assert.True(t, b.Update(l, 18)) + assert.Equal(t, int64(1), b.lostCounter.Count()) + assert.True(t, b.Update(l, 20)) + assert.Equal(t, int64(1), b.lostCounter.Count()) + assert.True(t, b.Update(l, 21)) + + // We missed packet 8 above + assert.Equal(t, int64(1), b.lostCounter.Count()) + assert.Equal(t, int64(0), b.dupeCounter.Count()) + assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) +} + func BenchmarkBits(b *testing.B) { z := NewBits(10) for n := 0; n < b.N; n++ { diff --git a/calculated_remote.go b/calculated_remote.go index 32d062a8..0e28bb42 100644 --- a/calculated_remote.go +++ b/calculated_remote.go @@ -84,16 +84,11 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu calculatedRemotes := new(bart.Table[[]*calculatedRemote]) - rawMap, ok := value.(map[any]any) + rawMap, ok := value.(map[string]any) if !ok { return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value) } - for rawKey, rawValue := range rawMap { - rawCIDR, ok := rawKey.(string) - if !ok { - return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey) - } - + for rawCIDR, rawValue := range rawMap { cidr, err := netip.ParsePrefix(rawCIDR) if err != nil { return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR) @@ -129,7 +124,7 @@ func newCalculatedRemotesListFromConfig(cidr netip.Prefix, raw any) ([]*calculat } func newCalculatedRemotesEntryFromConfig(cidr netip.Prefix, raw any) (*calculatedRemote, error) { - rawMap, ok := raw.(map[any]any) + rawMap, ok := raw.(map[string]any) if !ok { return nil, fmt.Errorf("invalid type: %T", raw) } diff --git a/cert/cert.go b/cert/cert.go index 38a25287..855815a7 100644 --- a/cert/cert.go +++ b/cert/cert.go @@ -58,6 +58,9 @@ type Certificate interface { // 
PublicKey is the raw bytes to be used in asymmetric cryptographic operations. PublicKey() []byte + // MarshalPublicKeyPEM is the value of PublicKey marshalled to PEM + MarshalPublicKeyPEM() []byte + // Curve identifies which curve was used for the PublicKey and Signature. Curve() Curve @@ -135,8 +138,7 @@ func Recombine(v Version, rawCertBytes, publicKey []byte, curve Curve) (Certific case Version2: c, err = unmarshalCertificateV2(rawCertBytes, publicKey, curve) default: - //TODO: CERT-V2 make a static var - return nil, fmt.Errorf("unknown certificate version %d", v) + return nil, ErrUnknownVersion } if err != nil { diff --git a/cert/cert_v1.go b/cert/cert_v1.go index 71d36eb8..09a181d6 100644 --- a/cert/cert_v1.go +++ b/cert/cert_v1.go @@ -83,6 +83,10 @@ func (c *certificateV1) PublicKey() []byte { return c.details.publicKey } +func (c *certificateV1) MarshalPublicKeyPEM() []byte { + return marshalCertPublicKeyToPEM(c) +} + func (c *certificateV1) Signature() []byte { return c.signature } @@ -110,8 +114,10 @@ func (c *certificateV1) CheckSignature(key []byte) bool { case Curve_CURVE25519: return ed25519.Verify(key, b, c.signature) case Curve_P256: - x, y := elliptic.Unmarshal(elliptic.P256(), key) - pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y} + pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key) + if err != nil { + return false + } hashed := sha256.Sum256(b) return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature) default: diff --git a/cert/cert_v1_test.go b/cert/cert_v1_test.go index c687172c..3b7d5859 100644 --- a/cert/cert_v1_test.go +++ b/cert/cert_v1_test.go @@ -1,6 +1,7 @@ package cert import ( + "crypto/ed25519" "fmt" "net/netip" "testing" @@ -13,6 +14,7 @@ import ( ) func TestCertificateV1_Marshal(t *testing.T) { + t.Parallel() before := time.Now().Add(time.Second * -60).Round(time.Second) after := time.Now().Add(time.Second * 60).Round(time.Second) pubKey := []byte("1234567890abcedfghij1234567890ab") @@ -60,6 +62,58 @@ 
func TestCertificateV1_Marshal(t *testing.T) { assert.Equal(t, nc.Groups(), nc2.Groups()) } +func TestCertificateV1_PublicKeyPem(t *testing.T) { + t.Parallel() + before := time.Now().Add(time.Second * -60).Round(time.Second) + after := time.Now().Add(time.Second * 60).Round(time.Second) + pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab") + + nc := certificateV1{ + details: detailsV1{ + name: "testing", + networks: []netip.Prefix{}, + unsafeNetworks: []netip.Prefix{}, + groups: []string{"test-group1", "test-group2", "test-group3"}, + notBefore: before, + notAfter: after, + publicKey: pubKey, + isCA: false, + issuer: "1234567890abcedfghij1234567890ab", + }, + signature: []byte("1234567890abcedfghij1234567890ab"), + } + + assert.Equal(t, Version1, nc.Version()) + assert.Equal(t, Curve_CURVE25519, nc.Curve()) + pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n" + assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem) + assert.False(t, nc.IsCA()) + + nc.details.isCA = true + assert.Equal(t, Curve_CURVE25519, nc.Curve()) + pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n" + assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem) + assert.True(t, nc.IsCA()) + + pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA +AAAAAAAAAAAAAAAAAAAAAAA= +-----END NEBULA P256 PUBLIC KEY----- +`) + pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem) + require.NoError(t, err) + nc.details.curve = Curve_P256 + nc.details.publicKey = pubP256Key + assert.Equal(t, Curve_P256, nc.Curve()) + assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem)) + assert.True(t, nc.IsCA()) + + nc.details.isCA = false + assert.Equal(t, Curve_P256, nc.Curve()) + assert.Equal(t, 
string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem)) + assert.False(t, nc.IsCA()) +} + func TestCertificateV1_Expired(t *testing.T) { nc := certificateV1{ details: detailsV1{ diff --git a/cert/cert_v2.go b/cert/cert_v2.go index 322463e9..ac21cb13 100644 --- a/cert/cert_v2.go +++ b/cert/cert_v2.go @@ -114,6 +114,10 @@ func (c *certificateV2) PublicKey() []byte { return c.publicKey } +func (c *certificateV2) MarshalPublicKeyPEM() []byte { + return marshalCertPublicKeyToPEM(c) +} + func (c *certificateV2) Signature() []byte { return c.signature } @@ -149,8 +153,10 @@ func (c *certificateV2) CheckSignature(key []byte) bool { case Curve_CURVE25519: return ed25519.Verify(key, b, c.signature) case Curve_P256: - x, y := elliptic.Unmarshal(elliptic.P256(), key) - pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y} + pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key) + if err != nil { + return false + } hashed := sha256.Sum256(b) return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature) default: diff --git a/cert/cert_v2_test.go b/cert/cert_v2_test.go index c84f8c99..84362efe 100644 --- a/cert/cert_v2_test.go +++ b/cert/cert_v2_test.go @@ -15,6 +15,7 @@ import ( ) func TestCertificateV2_Marshal(t *testing.T) { + t.Parallel() before := time.Now().Add(time.Second * -60).Round(time.Second) after := time.Now().Add(time.Second * 60).Round(time.Second) pubKey := []byte("1234567890abcedfghij1234567890ab") @@ -75,6 +76,58 @@ func TestCertificateV2_Marshal(t *testing.T) { assert.Equal(t, nc.Groups(), nc2.Groups()) } +func TestCertificateV2_PublicKeyPem(t *testing.T) { + t.Parallel() + before := time.Now().Add(time.Second * -60).Round(time.Second) + after := time.Now().Add(time.Second * 60).Round(time.Second) + pubKey := ed25519.PublicKey("1234567890abcedfghij1234567890ab") + + nc := certificateV2{ + details: detailsV2{ + name: "testing", + networks: []netip.Prefix{}, + unsafeNetworks: []netip.Prefix{}, + groups: []string{"test-group1", "test-group2", 
"test-group3"}, + notBefore: before, + notAfter: after, + isCA: false, + issuer: "1234567890abcedfghij1234567890ab", + }, + publicKey: pubKey, + signature: []byte("1234567890abcedfghij1234567890ab"), + } + + assert.Equal(t, Version2, nc.Version()) + assert.Equal(t, Curve_CURVE25519, nc.Curve()) + pubPem := "-----BEGIN NEBULA X25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA X25519 PUBLIC KEY-----\n" + assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem) + assert.False(t, nc.IsCA()) + + nc.details.isCA = true + assert.Equal(t, Curve_CURVE25519, nc.Curve()) + pubPem = "-----BEGIN NEBULA ED25519 PUBLIC KEY-----\nMTIzNDU2Nzg5MGFiY2VkZmdoaWoxMjM0NTY3ODkwYWI=\n-----END NEBULA ED25519 PUBLIC KEY-----\n" + assert.Equal(t, string(nc.MarshalPublicKeyPEM()), pubPem) + assert.True(t, nc.IsCA()) + + pubP256KeyPem := []byte(`-----BEGIN NEBULA P256 PUBLIC KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA +AAAAAAAAAAAAAAAAAAAAAAA= +-----END NEBULA P256 PUBLIC KEY----- +`) + pubP256Key, _, _, err := UnmarshalPublicKeyFromPEM(pubP256KeyPem) + require.NoError(t, err) + nc.curve = Curve_P256 + nc.publicKey = pubP256Key + assert.Equal(t, Curve_P256, nc.Curve()) + assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem)) + assert.True(t, nc.IsCA()) + + nc.details.isCA = false + assert.Equal(t, Curve_P256, nc.Curve()) + assert.Equal(t, string(nc.MarshalPublicKeyPEM()), string(pubP256KeyPem)) + assert.False(t, nc.IsCA()) +} + func TestCertificateV2_Expired(t *testing.T) { nc := certificateV2{ details: detailsV2{ diff --git a/cert/errors.go b/cert/errors.go index 4bbc023a..99006756 100644 --- a/cert/errors.go +++ b/cert/errors.go @@ -20,6 +20,7 @@ var ( ErrPublicPrivateKeyMismatch = errors.New("public key and private key are not a pair") ErrPrivateKeyEncrypted = errors.New("private key must be decrypted") ErrCaNotFound = errors.New("could not find ca for the certificate") + ErrUnknownVersion = 
errors.New("certificate version unrecognized") ErrInvalidPEMBlock = errors.New("input did not contain a valid PEM encoded block") ErrInvalidPEMCertificateBanner = errors.New("bytes did not contain a proper certificate banner") diff --git a/cert/pem.go b/cert/pem.go index 7ad28d12..a5aabdce 100644 --- a/cert/pem.go +++ b/cert/pem.go @@ -7,19 +7,26 @@ import ( "golang.org/x/crypto/ed25519" ) -const ( - CertificateBanner = "NEBULA CERTIFICATE" - CertificateV2Banner = "NEBULA CERTIFICATE V2" - X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY" - X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY" - EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY" - Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY" - Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY" +const ( //cert banners + CertificateBanner = "NEBULA CERTIFICATE" + CertificateV2Banner = "NEBULA CERTIFICATE V2" +) - P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY" - P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY" +const ( //key-agreement-key banners + X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY" + X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY" + P256PrivateKeyBanner = "NEBULA P256 PRIVATE KEY" + P256PublicKeyBanner = "NEBULA P256 PUBLIC KEY" +) + +/* including "ECDSA" in the P256 banners is a clue that these keys should be used only for signing */ +const ( //signing key banners EncryptedECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 ENCRYPTED PRIVATE KEY" ECDSAP256PrivateKeyBanner = "NEBULA ECDSA P256 PRIVATE KEY" + ECDSAP256PublicKeyBanner = "NEBULA ECDSA P256 PUBLIC KEY" + EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY" + Ed25519PrivateKeyBanner = "NEBULA ED25519 PRIVATE KEY" + Ed25519PublicKeyBanner = "NEBULA ED25519 PUBLIC KEY" ) // UnmarshalCertificateFromPEM will try to unmarshal the first pem block in a byte array, returning any non consumed @@ -51,6 +58,16 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) { } +func 
marshalCertPublicKeyToPEM(c Certificate) []byte { + if c.IsCA() { + return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey()) + } else { + return MarshalPublicKeyToPEM(c.Curve(), c.PublicKey()) + } +} + +// MarshalPublicKeyToPEM returns a PEM representation of a public key used for ECDH. +// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes! func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte { switch curve { case Curve_CURVE25519: @@ -62,6 +79,19 @@ func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte { } } +// MarshalSigningPublicKeyToPEM returns a PEM representation of a public key used for signing. +// if your public key came from a certificate, prefer Certificate.PublicKeyPEM() if possible, to avoid mistakes! +func MarshalSigningPublicKeyToPEM(curve Curve, b []byte) []byte { + switch curve { + case Curve_CURVE25519: + return pem.EncodeToMemory(&pem.Block{Type: Ed25519PublicKeyBanner, Bytes: b}) + case Curve_P256: + return pem.EncodeToMemory(&pem.Block{Type: P256PublicKeyBanner, Bytes: b}) + default: + return nil + } +} + func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) { k, r := pem.Decode(b) if k == nil { @@ -73,7 +103,7 @@ func UnmarshalPublicKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) { case X25519PublicKeyBanner, Ed25519PublicKeyBanner: expectedLen = 32 curve = Curve_CURVE25519 - case P256PublicKeyBanner: + case P256PublicKeyBanner, ECDSAP256PublicKeyBanner: // Uncompressed expectedLen = 65 curve = Curve_P256 diff --git a/cert/pem_test.go b/cert/pem_test.go index 6e492497..ff4410ce 100644 --- a/cert/pem_test.go +++ b/cert/pem_test.go @@ -177,6 +177,7 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= } func TestUnmarshalPublicKeyFromPEM(t *testing.T) { + t.Parallel() pubKey := []byte(`# A good key -----BEGIN NEBULA ED25519 PUBLIC KEY----- AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= @@ -230,6 +231,7 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= } func 
TestUnmarshalX25519PublicKey(t *testing.T) { + t.Parallel() pubKey := []byte(`# A good key -----BEGIN NEBULA X25519 PUBLIC KEY----- AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= @@ -240,6 +242,12 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA AAAAAAAAAAAAAAAAAAAAAAA= -----END NEBULA P256 PUBLIC KEY----- +`) + oldPubP256Key := []byte(`# A good key +-----BEGIN NEBULA ECDSA P256 PUBLIC KEY----- +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA +AAAAAAAAAAAAAAAAAAAAAAA= +-----END NEBULA ECDSA P256 PUBLIC KEY----- `) shortKey := []byte(`# A short key -----BEGIN NEBULA X25519 PUBLIC KEY----- @@ -256,15 +264,22 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= -END NEBULA X25519 PUBLIC KEY-----`) - keyBundle := appendByteSlices(pubKey, pubP256Key, shortKey, invalidBanner, invalidPem) + keyBundle := appendByteSlices(pubKey, pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem) // Success test case k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle) assert.Len(t, k, 32) require.NoError(t, err) - assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem)) + assert.Equal(t, rest, appendByteSlices(pubP256Key, oldPubP256Key, shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_CURVE25519, curve) + // Success test case + k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) + assert.Len(t, k, 65) + require.NoError(t, err) + assert.Equal(t, rest, appendByteSlices(oldPubP256Key, shortKey, invalidBanner, invalidPem)) + assert.Equal(t, Curve_P256, curve) + // Success test case k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) assert.Len(t, k, 65) diff --git a/cert/sign.go b/cert/sign.go index 12d4ee45..3eb08592 100644 --- a/cert/sign.go +++ b/cert/sign.go @@ -7,7 +7,6 @@ import ( "crypto/rand" "crypto/sha256" "fmt" - "math/big" "net/netip" "time" ) @@ -55,15 +54,10 @@ func (t *TBSCertificate) Sign(signer 
Certificate, curve Curve, key []byte) (Cert } return t.SignWith(signer, curve, sp) case Curve_P256: - pk := &ecdsa.PrivateKey{ - PublicKey: ecdsa.PublicKey{ - Curve: elliptic.P256(), - }, - // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95 - D: new(big.Int).SetBytes(key), + pk, err := ecdsa.ParseRawPrivateKey(elliptic.P256(), key) + if err != nil { + return nil, err } - // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119 - pk.X, pk.Y = pk.Curve.ScalarBaseMult(key) sp := func(certBytes []byte) ([]byte, error) { // We need to hash first for ECDSA // - https://pkg.go.dev/crypto/ecdsa#SignASN1 diff --git a/cert_test/cert.go b/cert_test/cert.go index ebc6f522..75134316 100644 --- a/cert_test/cert.go +++ b/cert_test/cert.go @@ -114,6 +114,33 @@ func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []by return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem } +func NewTestCertDifferentVersion(c cert.Certificate, v cert.Version, ca cert.Certificate, key []byte) (cert.Certificate, []byte) { + nc := &cert.TBSCertificate{ + Version: v, + Curve: c.Curve(), + Name: c.Name(), + Networks: c.Networks(), + UnsafeNetworks: c.UnsafeNetworks(), + Groups: c.Groups(), + NotBefore: time.Unix(c.NotBefore().Unix(), 0), + NotAfter: time.Unix(c.NotAfter().Unix(), 0), + PublicKey: c.PublicKey(), + IsCA: false, + } + + c, err := nc.Sign(ca, ca.Curve(), key) + if err != nil { + panic(err) + } + + pem, err := c.MarshalPEM() + if err != nil { + panic(err) + } + + return c, pem +} + func X25519Keypair() ([]byte, []byte) { privkey := make([]byte, 32) if _, err := io.ReadFull(rand.Reader, privkey); err != nil { diff --git a/cmd/nebula-cert/ca.go b/cmd/nebula-cert/ca.go index f83c94fb..cd9b82f9 100644 --- a/cmd/nebula-cert/ca.go +++ b/cmd/nebula-cert/ca.go @@ -173,23 +173,26 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error var passphrase []byte if !isP11 && *cf.encryption { - for i := 
0; i < 5; i++ { - out.Write([]byte("Enter passphrase: ")) - passphrase, err = pr.ReadPassword() - - if err == ErrNoTerminal { - return fmt.Errorf("out-key must be encrypted interactively") - } else if err != nil { - return fmt.Errorf("error reading passphrase: %s", err) - } - - if len(passphrase) > 0 { - break - } - } - + passphrase = []byte(os.Getenv("NEBULA_CA_PASSPHRASE")) if len(passphrase) == 0 { - return fmt.Errorf("no passphrase specified, remove -encrypt flag to write out-key in plaintext") + for i := 0; i < 5; i++ { + out.Write([]byte("Enter passphrase: ")) + passphrase, err = pr.ReadPassword() + + if err == ErrNoTerminal { + return fmt.Errorf("out-key must be encrypted interactively") + } else if err != nil { + return fmt.Errorf("error reading passphrase: %s", err) + } + + if len(passphrase) > 0 { + break + } + } + + if len(passphrase) == 0 { + return fmt.Errorf("no passphrase specified, remove -encrypt flag to write out-key in plaintext") + } } } diff --git a/cmd/nebula-cert/ca_test.go b/cmd/nebula-cert/ca_test.go index b1cbde92..cd3f0bf9 100644 --- a/cmd/nebula-cert/ca_test.go +++ b/cmd/nebula-cert/ca_test.go @@ -171,6 +171,17 @@ func Test_ca(t *testing.T) { assert.Equal(t, pwPromptOb, ob.String()) assert.Empty(t, eb.String()) + // test encrypted key with passphrase environment variable + os.Remove(keyF.Name()) + os.Remove(crtF.Name()) + ob.Reset() + eb.Reset() + args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + os.Setenv("NEBULA_CA_PASSPHRASE", string(passphrase)) + require.NoError(t, ca(args, ob, eb, testpw)) + assert.Empty(t, eb.String()) + os.Setenv("NEBULA_CA_PASSPHRASE", "") + // read encrypted key file and verify default params rb, _ = os.ReadFile(keyF.Name()) k, _ := pem.Decode(rb) diff --git a/cmd/nebula-cert/main.go b/cmd/nebula-cert/main.go index c88626f9..82ad2fee 100644 --- a/cmd/nebula-cert/main.go +++ b/cmd/nebula-cert/main.go @@ 
-5,10 +5,28 @@ import ( "fmt" "io" "os" + "runtime/debug" + "strings" ) +// A version string that can be set with +// +// -ldflags "-X main.Build=SOMEVERSION" +// +// at compile-time. var Build string +func init() { + if Build == "" { + info, ok := debug.ReadBuildInfo() + if !ok { + return + } + + Build = strings.TrimPrefix(info.Main.Version, "v") + } +} + type helpError struct { s string } diff --git a/cmd/nebula-cert/sign.go b/cmd/nebula-cert/sign.go index ebcb592e..561138ca 100644 --- a/cmd/nebula-cert/sign.go +++ b/cmd/nebula-cert/sign.go @@ -43,7 +43,7 @@ type signFlags struct { func newSignFlags() *signFlags { sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)} sf.set.Usage = func() {} - sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use, the default is to create both v1 and v2 certificates.") + sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use. The default is to match the version of the signing CA") sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key") sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert") sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname") @@ -116,26 +116,28 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) // naively attempt to decode the private key as though it is not encrypted caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey) if errors.Is(err, cert.ErrPrivateKeyEncrypted) { - // ask for a passphrase until we get one var passphrase []byte - for i := 0; i < 5; i++ { - out.Write([]byte("Enter passphrase: ")) - passphrase, err = pr.ReadPassword() - - if errors.Is(err, ErrNoTerminal) { - return fmt.Errorf("ca-key is encrypted and must be decrypted interactively") - } else if err != nil { - return fmt.Errorf("error reading password: %s", err) - } - - if len(passphrase) > 0 { - break - } - 
} + passphrase = []byte(os.Getenv("NEBULA_CA_PASSPHRASE")) if len(passphrase) == 0 { - return fmt.Errorf("cannot open encrypted ca-key without passphrase") - } + // ask for a passphrase until we get one + for i := 0; i < 5; i++ { + out.Write([]byte("Enter passphrase: ")) + passphrase, err = pr.ReadPassword() + if errors.Is(err, ErrNoTerminal) { + return fmt.Errorf("ca-key is encrypted and must be decrypted interactively") + } else if err != nil { + return fmt.Errorf("error reading password: %s", err) + } + + if len(passphrase) > 0 { + break + } + } + if len(passphrase) == 0 { + return fmt.Errorf("cannot open encrypted ca-key without passphrase") + } + } curve, caKey, _, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rawCAKey) if err != nil { return fmt.Errorf("error while parsing encrypted ca-key: %s", err) @@ -165,6 +167,10 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) return fmt.Errorf("ca certificate is expired") } + if version == 0 { + version = caCert.Version() + } + // if no duration is given, expire one second before the root expires if *sf.duration <= 0 { *sf.duration = time.Until(caCert.NotAfter()) - time.Second*1 @@ -277,21 +283,19 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) notBefore := time.Now() notAfter := notBefore.Add(*sf.duration) - if version == 0 || version == cert.Version1 { - // Make sure we at least have an ip + switch version { + case cert.Version1: + // Make sure we have only one ipv4 address if len(v4Networks) != 1 { return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address") } - if version == cert.Version1 { - // If we are asked to mint a v1 certificate only then we cant just ignore any v6 addresses - if len(v6Networks) > 0 { - return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4") - } + if len(v6Networks) > 0 { + return newHelpErrorf("invalid -networks definition: v1 
certificates can only contain ipv4 addresses") + } - if len(v6UnsafeNetworks) > 0 { - return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4") - } + if len(v6UnsafeNetworks) > 0 { + return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only contain ipv4 addresses") } t := &cert.TBSCertificate{ @@ -321,9 +325,8 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) } crts = append(crts, nc) - } - if version == 0 || version == cert.Version2 { + case cert.Version2: t := &cert.TBSCertificate{ Version: cert.Version2, Name: *sf.name, @@ -351,6 +354,9 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) } crts = append(crts, nc) + default: + // this should be unreachable + return fmt.Errorf("invalid version: %d", version) } if !isP11 && *sf.inPubPath == "" { diff --git a/cmd/nebula-cert/sign_test.go b/cmd/nebula-cert/sign_test.go index b2bba762..f5f8cbb0 100644 --- a/cmd/nebula-cert/sign_test.go +++ b/cmd/nebula-cert/sign_test.go @@ -55,7 +55,7 @@ func Test_signHelp(t *testing.T) { " -unsafe-networks string\n"+ " \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+ " -version uint\n"+ - " \tOptional: version of the certificate format to use, the default is to create both v1 and v2 certificates.\n", + " \tOptional: version of the certificate format to use. 
The default is to match the version of the signing CA\n", ob.String(), ) } @@ -204,7 +204,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"} - assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4") + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only contain ipv4 addresses") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -379,6 +379,15 @@ func Test_signCert(t *testing.T) { assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) + // test with the proper password in the environment + os.Remove(crtF.Name()) + os.Remove(keyF.Name()) + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + os.Setenv("NEBULA_CA_PASSPHRASE", string(passphrase)) + require.NoError(t, signCert(args, ob, eb, testpw)) + assert.Empty(t, eb.String()) + os.Setenv("NEBULA_CA_PASSPHRASE", "") + // test with the wrong password ob.Reset() eb.Reset() @@ -389,6 +398,17 @@ func Test_signCert(t *testing.T) { assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) + // test with the wrong password in environment + ob.Reset() + eb.Reset() + + os.Setenv("NEBULA_CA_PASSPHRASE", "invalid password") + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , 
,,,3,4,5"} + require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing encrypted ca-key: invalid passphrase or corrupt private key") + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + os.Setenv("NEBULA_CA_PASSPHRASE", "") + // test with the user not entering a password ob.Reset() eb.Reset() diff --git a/cmd/nebula-service/main.go b/cmd/nebula-service/main.go index 8d0eaa1d..9a17b947 100644 --- a/cmd/nebula-service/main.go +++ b/cmd/nebula-service/main.go @@ -4,6 +4,8 @@ import ( "flag" "fmt" "os" + "runtime/debug" + "strings" "github.com/sirupsen/logrus" "github.com/slackhq/nebula" @@ -18,6 +20,17 @@ import ( // at compile-time. var Build string +func init() { + if Build == "" { + info, ok := debug.ReadBuildInfo() + if !ok { + return + } + + Build = strings.TrimPrefix(info.Main.Version, "v") + } +} + func main() { serviceFlag := flag.String("service", "", "Control the system service.") configPath := flag.String("config", "", "Path to either a file or directory to load configuration from") diff --git a/cmd/nebula/main.go b/cmd/nebula/main.go index 5cf0a028..ffdc15bf 100644 --- a/cmd/nebula/main.go +++ b/cmd/nebula/main.go @@ -4,6 +4,8 @@ import ( "flag" "fmt" "os" + "runtime/debug" + "strings" "github.com/sirupsen/logrus" "github.com/slackhq/nebula" @@ -18,6 +20,17 @@ import ( // at compile-time. var Build string +func init() { + if Build == "" { + info, ok := debug.ReadBuildInfo() + if !ok { + return + } + + Build = strings.TrimPrefix(info.Main.Version, "v") + } +} + func main() { configPath := flag.String("config", "", "Path to either a file or directory to load configuration from") configTest := flag.Bool("test", false, "Test the config and print the end result. 
Non zero exit indicates a faulty config") diff --git a/config/config.go b/config/config.go index 55103245..0d1be128 100644 --- a/config/config.go +++ b/config/config.go @@ -17,7 +17,7 @@ import ( "dario.cat/mergo" "github.com/sirupsen/logrus" - "gopkg.in/yaml.v3" + "go.yaml.in/yaml/v3" ) type C struct { diff --git a/config/config_test.go b/config/config_test.go index ec5a4b0f..aba7b2a8 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -10,7 +10,7 @@ import ( "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/yaml.v3" + "go.yaml.in/yaml/v3" ) func TestConfig_Load(t *testing.T) { diff --git a/connection_manager.go b/connection_manager.go index 1f9b18b9..4c2f26ef 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -354,9 +354,8 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim if mainHostInfo { decision = tryRehandshake - } else { - if cm.shouldSwapPrimary(hostinfo, primary) { + if cm.shouldSwapPrimary(hostinfo) { decision = swapPrimary } else { // migrate the relays to the primary, if in use. @@ -447,7 +446,7 @@ func (cm *connectionManager) isInactive(hostinfo *HostInfo, now time.Time) (time return inactiveDuration, true } -func (cm *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool { +func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool { // The primary tunnel is the most recent handshake to complete locally and should work entirely fine. // If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary. // Let's sort this out. @@ -461,6 +460,10 @@ func (cm *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool } crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version()) + if crt == nil { + //my cert was reloaded away. 
We should definitely swap from this tunnel + return true + } // If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things // settle down. return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature()) @@ -475,31 +478,34 @@ func (cm *connectionManager) swapPrimary(current, primary *HostInfo) { cm.hostMap.Unlock() } -// isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and -// the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid -// check and return true. +// isInvalidCertificate decides if we should destroy a tunnel. +// returns true if pki.disconnect_invalid is true and the certificate is no longer valid. +// Blocklisted certificates will skip the pki.disconnect_invalid check and return true. func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool { remoteCert := hostinfo.GetCert() if remoteCert == nil { - return false + return false //don't tear down tunnels for handshakes in progress } caPool := cm.intf.pki.GetCAPool() err := caPool.VerifyCachedCertificate(now, remoteCert) if err == nil { - return false - } - - if !cm.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed { + return false //cert is still valid! yay! + } else if err == cert.ErrBlockListed { //avoiding errors.Is for speed // Block listed certificates should always be disconnected + hostinfo.logger(cm.l).WithError(err). + WithField("fingerprint", remoteCert.Fingerprint). + Info("Remote certificate is blocked, tearing down the tunnel") + return true + } else if cm.intf.disconnectInvalid.Load() { + hostinfo.logger(cm.l).WithError(err). + WithField("fingerprint", remoteCert.Fingerprint). 
+ Info("Remote certificate is no longer valid, tearing down the tunnel") + return true + } else { + //if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open return false } - - hostinfo.logger(cm.l).WithError(err). - WithField("fingerprint", remoteCert.Fingerprint). - Info("Remote certificate is no longer valid, tearing down the tunnel") - - return true } func (cm *connectionManager) sendPunch(hostinfo *HostInfo) { @@ -530,15 +536,45 @@ func (cm *connectionManager) sendPunch(hostinfo *HostInfo) { func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) { cs := cm.intf.pki.getCertState() curCrt := hostinfo.ConnectionState.myCert - myCrt := cs.getCertificate(curCrt.Version()) - if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true { - // The current tunnel is using the latest certificate and version, no need to rehandshake. + curCrtVersion := curCrt.Version() + myCrt := cs.getCertificate(curCrtVersion) + if myCrt == nil { + cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). + WithField("version", curCrtVersion). + WithField("reason", "local certificate removed"). + Info("Re-handshaking with remote") + cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) return } + peerCrt := hostinfo.ConnectionState.peerCert + if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() { + // if our certificate version is less than theirs, and we have a matching version available, rehandshake? + if cs.getCertificate(peerCrt.Certificate.Version()) != nil { + cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). + WithField("version", curCrtVersion). + WithField("peerVersion", peerCrt.Certificate.Version()). + WithField("reason", "local certificate version lower than peer, attempting to correct"). 
+ Info("Re-handshaking with remote") + cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) { + hh.initiatingVersionOverride = peerCrt.Certificate.Version() + }) + return + } + } + if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) { + cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). + WithField("reason", "local certificate is not current"). + Info("Re-handshaking with remote") - cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). - WithField("reason", "local certificate is not current"). - Info("Re-handshaking with remote") + cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) + return + } + if curCrtVersion < cs.initiatingVersion { + cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). + WithField("reason", "current cert version < pki.initiatingVersion"). + Info("Re-handshaking with remote") - cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) + cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) + return + } } diff --git a/connection_manager_test.go b/connection_manager_test.go index ecd28801..647dd72b 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -22,7 +22,7 @@ func newTestLighthouse() *LightHouse { addrMap: map[netip.Addr]*RemoteList{}, queryChan: make(chan netip.Addr, 10), } - lighthouses := map[netip.Addr]struct{}{} + lighthouses := []netip.Addr{} staticList := map[netip.Addr]struct{}{} lh.lighthouses.Store(&lighthouses) @@ -446,6 +446,10 @@ func (d *dummyCert) PublicKey() []byte { return d.publicKey } +func (d *dummyCert) MarshalPublicKeyPEM() []byte { + return cert.MarshalPublicKeyToPEM(d.curve, d.publicKey) +} + func (d *dummyCert) Signature() []byte { return d.signature } diff --git a/connection_state.go b/connection_state.go index faee443d..db885d42 100644 --- a/connection_state.go +++ b/connection_state.go @@ -50,11 +50,6 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i } static := noise.DHKey{Private: 
cs.privateKey, Public: crt.PublicKey()} - - b := NewBits(ReplayWindow) - // Clear out bit 0, we never transmit it, and we don't want it showing as packet loss - b.Update(l, 0) - hs, err := noise.NewHandshakeState(noise.Config{ CipherSuite: ncs, Random: rand.Reader, @@ -74,7 +69,7 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i ci := &ConnectionState{ H: hs, initiator: initiator, - window: b, + window: NewBits(ReplayWindow), myCert: crt, } // always start the counter from 2, as packet 1 and packet 2 are handshake packets. diff --git a/control_tester.go b/control_tester.go index 451dac53..7403a745 100644 --- a/control_tester.go +++ b/control_tester.go @@ -174,6 +174,10 @@ func (c *Control) GetHostmap() *HostMap { return c.f.hostMap } +func (c *Control) GetF() *Interface { + return c.f +} + func (c *Control) GetCertState() *CertState { return c.f.pki.getCertState() } diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index 53d37386..67b166b1 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -20,16 +20,17 @@ import ( "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/yaml.v3" + "go.yaml.in/yaml/v3" ) func BenchmarkHotPath(b *testing.B) { ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Start the servers myControl.Start() @@ -38,6 +39,9 @@ 
func BenchmarkHotPath(b *testing.B) { r := router.NewR(b, myControl, theirControl) r.CancelFlowLogs() + assertTunnel(b, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + b.ResetTimer() + for n := 0; n < b.N; n++ { myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) _ = r.RouteForAllUntilTxTun(theirControl) @@ -47,6 +51,39 @@ func BenchmarkHotPath(b *testing.B) { theirControl.Stop() } +func BenchmarkHotPathRelay(b *testing.B) { + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) + + // Teach my how to get to the relay and that their can be reached via the relay + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(b, myControl, relayControl, theirControl) + r.CancelFlowLogs() + + // Start the servers + myControl.Start() + relayControl.Start() + theirControl.Start() + + assertTunnel(b, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) + b.ResetTimer() + + for n := 0; n < b.N; n++ { + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + _ = r.RouteForAllUntilTxTun(theirControl) + } + + myControl.Stop() + 
theirControl.Stop() + relayControl.Stop() +} + func TestGoodHandshake(t *testing.T) { ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) @@ -97,6 +134,41 @@ func TestGoodHandshake(t *testing.T) { theirControl.Stop() } +func TestGoodHandshakeNoOverlap(t *testing.T) { + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "2001::69/24", nil) //look ma, cross-stack! + + // Put their info in our lighthouse + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + + // Start the servers + myControl.Start() + theirControl.Start() + + empty := []byte{} + t.Log("do something to cause a handshake") + myControl.GetF().SendMessageToVpnAddr(header.Test, header.MessageNone, theirVpnIpNet[0].Addr(), empty, empty, empty) + + t.Log("Have them consume my stage 0 packet. They have a tunnel now") + theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) + + t.Log("Get their stage 1 packet") + stage1Packet := theirControl.GetFromUDP(true) + + t.Log("Have me consume their stage 1 packet. 
I have a tunnel now") + myControl.InjectUDPPacket(stage1Packet) + + t.Log("Wait until we see a test packet come through to make sure we give the tunnel time to complete") + myControl.WaitForType(header.Test, 0, theirControl) + + t.Log("Make sure our host infos are correct") + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl) + + myControl.Stop() + theirControl.Stop() +} + func TestWrongResponderHandshake(t *testing.T) { ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) @@ -464,6 +536,35 @@ func TestRelays(t *testing.T) { r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl) } +func TestRelaysDontCareAboutIps(t *testing.T) { + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "2001::9999/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) + + // Teach my how to get to the relay and that their can be reached via the relay + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(t, myControl, relayControl, theirControl) + defer r.RenderFlow() + + // Start the servers + myControl.Start() + relayControl.Start() + theirControl.Start() + + t.Log("Trigger a handshake from me to them 
via the relay") + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + + p := r.RouteForAllUntilTxTun(theirControl) + r.Log("Assert the tunnel works") + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) + r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl) +} + func TestReestablishRelays(t *testing.T) { ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) @@ -1227,3 +1328,109 @@ func TestV2NonPrimaryWithLighthouse(t *testing.T) { myControl.Stop() theirControl.Stop() } + +func TestV2NonPrimaryWithOffNetLighthouse(t *testing.T) { + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "2001::1/64", m{"lighthouse": m{"am_lighthouse": true}}) + + o := m{ + "static_host_map": m{ + lhVpnIpNet[0].Addr().String(): []string{lhUdpAddr.String()}, + }, + "lighthouse": m{ + "hosts": []string{lhVpnIpNet[0].Addr().String()}, + "local_allow_list": m{ + // Try and block our lighthouse updates from using the actual addresses assigned to this computer + // If we start discovering addresses the test router doesn't know about then test traffic cant flow + "10.0.0.0/24": true, + "::/0": false, + }, + }, + } + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.2/24, ff::2/64", o) + theirControl, theirVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.128.0.3/24, ff::3/64", o) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(t, lhControl, myControl, theirControl) + 
defer r.RenderFlow() + + // Start the servers + lhControl.Start() + myControl.Start() + theirControl.Start() + + t.Log("Stand up an ipv6 tunnel between me and them") + assert.True(t, myVpnIpNet[1].Addr().Is6()) + assert.True(t, theirVpnIpNet[1].Addr().Is6()) + assertTunnel(t, myVpnIpNet[1].Addr(), theirVpnIpNet[1].Addr(), myControl, theirControl, r) + + lhControl.Stop() + myControl.Stop() + theirControl.Stop() +} + +func TestGoodHandshakeUnsafeDest(t *testing.T) { + unsafePrefix := "192.168.6.0/24" + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks(cert.Version2, ca, caKey, "spooky", "10.128.0.2/24", netip.MustParseAddrPort("10.64.0.2:4242"), unsafePrefix, nil) + route := m{"route": unsafePrefix, "via": theirVpnIpNet[0].Addr().String()} + myCfg := m{ + "tun": m{ + "unsafe_routes": []m{route}, + }, + } + myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", myCfg) + t.Logf("my config %v", myConfig) + // Put their info in our lighthouse + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + + spookyDest := netip.MustParseAddr("192.168.6.4") + + // Start the servers + myControl.Start() + theirControl.Start() + + t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side") + myControl.InjectTunUDPPacket(spookyDest, 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + + t.Log("Have them consume my stage 0 packet. 
They have a tunnel now") + theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) + + t.Log("Get their stage 1 packet so that we can play with it") + stage1Packet := theirControl.GetFromUDP(true) + + t.Log("I consume a garbage packet with a proper nebula header for our tunnel") + // this should log a statement and get ignored, allowing the real handshake packet to complete the tunnel + badPacket := stage1Packet.Copy() + badPacket.Data = badPacket.Data[:len(badPacket.Data)-header.Len] + myControl.InjectUDPPacket(badPacket) + + t.Log("Have me consume their real stage 1 packet. I have a tunnel now") + myControl.InjectUDPPacket(stage1Packet) + + t.Log("Wait until we see my cached packet come through") + myControl.WaitForType(1, 0, theirControl) + + t.Log("Make sure our host infos are correct") + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl) + + t.Log("Get that cached packet and make sure it looks right") + myCachedPacket := theirControl.GetFromTun(true) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), spookyDest, 80, 80) + + //reply + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, spookyDest, 80, []byte("Hi from the spookyman")) + //wait for reply + theirControl.WaitForType(1, 0, myControl) + theirCachedPacket := myControl.GetFromTun(true) + assertUdpPacket(t, []byte("Hi from the spookyman"), theirCachedPacket, spookyDest, myVpnIpNet[0].Addr(), 80, 80) + + t.Log("Do a bidirectional tunnel test") + r := router.NewR(t, myControl, theirControl) + defer r.RenderFlow() + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + r.RenderHostmaps("Final hostmaps", myControl, theirControl) + myControl.Stop() + theirControl.Stop() +} diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index a63b3d01..7a802c99 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -22,15 +22,14 @@ import ( "github.com/slackhq/nebula/config" 
"github.com/slackhq/nebula/e2e/router" "github.com/stretchr/testify/assert" - "gopkg.in/yaml.v3" + "github.com/stretchr/testify/require" + "go.yaml.in/yaml/v3" ) type m = map[string]any // newSimpleServer creates a nebula instance with many assumptions func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) { - l := NewTestLogger() - var vpnNetworks []netip.Prefix for _, sn := range strings.Split(sVpnNetworks, ",") { vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn)) @@ -56,7 +55,54 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name budpIp[3] = 239 udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242) } - _, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{}) + return newSimpleServerWithUdp(v, caCrt, caKey, name, sVpnNetworks, udpAddr, overrides) +} + +func newSimpleServerWithUdp(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) { + return newSimpleServerWithUdpAndUnsafeNetworks(v, caCrt, caKey, name, sVpnNetworks, udpAddr, "", overrides) +} + +func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, sUnsafeNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) { + l := NewTestLogger() + + var vpnNetworks []netip.Prefix + for _, sn := range strings.Split(sVpnNetworks, ",") { + vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn)) + if err != nil { + panic(err) + } + vpnNetworks = append(vpnNetworks, vpnIpNet) + } + + if len(vpnNetworks) == 0 { + panic("no vpn networks") + } + + firewallInbound := []m{{ + "proto": 
"any", + "port": "any", + "host": "any", + }} + + var unsafeNetworks []netip.Prefix + if sUnsafeNetworks != "" { + firewallInbound = []m{{ + "proto": "any", + "port": "any", + "host": "any", + "local_cidr": "0.0.0.0/0", + }} + + for _, sn := range strings.Split(sUnsafeNetworks, ",") { + x, err := netip.ParsePrefix(strings.TrimSpace(sn)) + if err != nil { + panic(err) + } + unsafeNetworks = append(unsafeNetworks, x) + } + } + + _, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, unsafeNetworks, []string{}) caB, err := caCrt.MarshalPEM() if err != nil { @@ -76,11 +122,7 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name "port": "any", "host": "any", }}, - "inbound": []m{{ - "proto": "any", - "port": "any", - "host": "any", - }}, + "inbound": firewallInbound, }, //"handshakes": m{ // "try_interval": "1s", @@ -129,6 +171,109 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name return control, vpnNetworks, udpAddr, c } +// newServer creates a nebula instance with fewer assumptions +func newServer(caCrt []cert.Certificate, certs []cert.Certificate, key []byte, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) { + l := NewTestLogger() + + vpnNetworks := certs[len(certs)-1].Networks() + + var udpAddr netip.AddrPort + if vpnNetworks[0].Addr().Is4() { + budpIp := vpnNetworks[0].Addr().As4() + budpIp[1] -= 128 + udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242) + } else { + budpIp := vpnNetworks[0].Addr().As16() + // beef for funsies + budpIp[2] = 190 + budpIp[3] = 239 + udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242) + } + + caStr := "" + for _, ca := range caCrt { + x, err := ca.MarshalPEM() + if err != nil { + panic(err) + } + caStr += string(x) + } + certStr := "" + for _, c := range certs { + x, err := c.MarshalPEM() + if err != nil { + panic(err) + } + certStr += 
string(x) + } + + mc := m{ + "pki": m{ + "ca": caStr, + "cert": certStr, + "key": string(key), + }, + //"tun": m{"disabled": true}, + "firewall": m{ + "outbound": []m{{ + "proto": "any", + "port": "any", + "host": "any", + }}, + "inbound": []m{{ + "proto": "any", + "port": "any", + "host": "any", + }}, + }, + //"handshakes": m{ + // "try_interval": "1s", + //}, + "listen": m{ + "host": udpAddr.Addr().String(), + "port": udpAddr.Port(), + }, + "logging": m{ + "timestamp_format": fmt.Sprintf("%v 15:04:05.000000", certs[0].Name()), + "level": l.Level.String(), + }, + "timers": m{ + "pending_deletion_interval": 2, + "connection_alive_interval": 2, + }, + } + + if overrides != nil { + final := m{} + err := mergo.Merge(&final, overrides, mergo.WithAppendSlice) + if err != nil { + panic(err) + } + err = mergo.Merge(&final, mc, mergo.WithAppendSlice) + if err != nil { + panic(err) + } + mc = final + } + + cb, err := yaml.Marshal(mc) + if err != nil { + panic(err) + } + + c := config.NewC(l) + cStr := string(cb) + c.LoadString(cStr) + + control, err := nebula.Main(c, false, "e2e-test", l, nil) + + if err != nil { + panic(err) + } + + return control, vpnNetworks, udpAddr, c +} + type doneCb func() func deadline(t *testing.T, seconds time.Duration) doneCb { @@ -147,7 +292,7 @@ func deadline(t *testing.T, seconds time.Duration) doneCb { } } -func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) { +func assertTunnel(t testing.TB, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) { // Send a packet from them to me controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B")) bPacket := r.RouteForAllUntilTxTun(controlA) @@ -163,10 +308,10 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpn // Get both host infos //TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false) - 
assert.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA") + require.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA") hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false) - assert.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB") + require.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB") // Check that both vpn and real addr are correct assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A") @@ -180,7 +325,7 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpn assert.Equal(t, hBinA.RemoteIndex, hAinB.LocalIndex, "Host B remote index does not match host A local index") } -func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { +func assertUdpPacket(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { if toIp.Is6() { assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort) } else { @@ -188,7 +333,7 @@ func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, } } -func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { +func assertUdpPacket6(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy) v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6) assert.NotNil(t, v6, "No ipv6 data found") @@ -207,7 +352,7 @@ func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, assert.Equal(t, expected, data.Payload(), "Data was incorrect") } -func assertUdpPacket4(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { +func assertUdpPacket4(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy) v4 := 
packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) assert.NotNil(t, v4, "No ipv4 data found") diff --git a/e2e/tunnels_test.go b/e2e/tunnels_test.go index 55974f0f..e89cf869 100644 --- a/e2e/tunnels_test.go +++ b/e2e/tunnels_test.go @@ -4,12 +4,16 @@ package e2e import ( + "fmt" + "net/netip" "testing" "time" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/e2e/router" + "github.com/stretchr/testify/assert" + "gopkg.in/yaml.v3" ) func TestDropInactiveTunnels(t *testing.T) { @@ -55,3 +59,309 @@ func TestDropInactiveTunnels(t *testing.T) { myControl.Stop() theirControl.Stop() } + +func TestCertUpgrade(t *testing.T) { + // The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides + // under ideal conditions + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + caB, err := ca.MarshalPEM() + if err != nil { + panic(err) + } + ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + + ca2B, err := ca2.MarshalPEM() + if err != nil { + panic(err) + } + caStr := fmt.Sprintf("%s\n%s", caB, ca2B) + + myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{}) + _, myCert2Pem := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2) + + theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{}) + theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2) + + myControl, myVpnIpNet, myUdpAddr, myC := 
newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert}, myPrivKey, m{}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{}) + + // Share our underlay information + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + + // Start the servers + myControl.Start() + theirControl.Start() + + r := router.NewR(t, myControl, theirControl) + defer r.RenderFlow() + + r.Log("Assert the tunnel between me and them works") + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + r.Log("yay") + //todo ??? + time.Sleep(1 * time.Second) + r.FlushAll() + + mc := m{ + "pki": m{ + "ca": caStr, + "cert": string(myCert2Pem), + "key": string(myPrivKey), + }, + //"tun": m{"disabled": true}, + "firewall": myC.Settings["firewall"], + //"handshakes": m{ + // "try_interval": "1s", + //}, + "listen": myC.Settings["listen"], + "logging": myC.Settings["logging"], + "timers": myC.Settings["timers"], + } + + cb, err := yaml.Marshal(mc) + if err != nil { + panic(err) + } + + r.Logf("reload new v2-only config") + err = myC.ReloadConfigString(string(cb)) + assert.NoError(t, err) + r.Log("yay, spin until their sees it") + waitStart := time.Now() + for { + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) + if c == nil { + r.Log("nil") + } else { + version := c.Cert.Version() + r.Logf("version %d", version) + if version == cert.Version2 { + break + } + } + since := time.Since(waitStart) + if since > time.Second*10 { + t.Fatal("Cert should be new by now") + } + time.Sleep(time.Second) + } + + r.RenderHostmaps("Final hostmaps", myControl, theirControl) + + myControl.Stop() + theirControl.Stop() +} + +func TestCertDowngrade(t *testing.T) { + // The goal of this test 
is to ensure the shortest inactivity timeout will close the tunnel on both sides + // under ideal conditions + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + caB, err := ca.MarshalPEM() + if err != nil { + panic(err) + } + ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + + ca2B, err := ca2.MarshalPEM() + if err != nil { + panic(err) + } + caStr := fmt.Sprintf("%s\n%s", caB, ca2B) + + myCert, _, myPrivKey, myCertPem := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{}) + myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2) + + theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{}) + theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2) + + myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{}) + + // Share our underlay information + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + + // Start the servers + myControl.Start() + theirControl.Start() + + r := router.NewR(t, myControl, theirControl) + defer r.RenderFlow() + + r.Log("Assert the tunnel between me and them works") + //assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) + 
//r.Log("yay") + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + r.Log("yay") + //todo ??? + time.Sleep(1 * time.Second) + r.FlushAll() + + mc := m{ + "pki": m{ + "ca": caStr, + "cert": string(myCertPem), + "key": string(myPrivKey), + }, + "firewall": myC.Settings["firewall"], + "listen": myC.Settings["listen"], + "logging": myC.Settings["logging"], + "timers": myC.Settings["timers"], + } + + cb, err := yaml.Marshal(mc) + if err != nil { + panic(err) + } + + r.Logf("reload new v1-only config") + err = myC.ReloadConfigString(string(cb)) + assert.NoError(t, err) + r.Log("yay, spin until their sees it") + waitStart := time.Now() + for { + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) + c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false) + if c == nil || c2 == nil { + r.Log("nil") + } else { + version := c.Cert.Version() + theirVersion := c2.Cert.Version() + r.Logf("version %d,%d", version, theirVersion) + if version == cert.Version1 { + break + } + } + since := time.Since(waitStart) + if since > time.Second*5 { + r.Log("it is unusual that the cert is not new yet, but not a failure yet") + } + if since > time.Second*10 { + r.Log("wtf") + t.Fatal("Cert should be new by now") + } + time.Sleep(time.Second) + } + + r.RenderHostmaps("Final hostmaps", myControl, theirControl) + + myControl.Stop() + theirControl.Stop() +} + +func TestCertMismatchCorrection(t *testing.T) { + // The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides + // under ideal conditions + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + 
+ myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{}) + myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2) + + theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{}) + theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2) + + myControl, myVpnIpNet, myUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{}) + + // Share our underlay information + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + + // Start the servers + myControl.Start() + theirControl.Start() + + r := router.NewR(t, myControl, theirControl) + defer r.RenderFlow() + + r.Log("Assert the tunnel between me and them works") + //assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) + //r.Log("yay") + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + r.Log("yay") + //todo ??? 
+ time.Sleep(1 * time.Second) + r.FlushAll() + + waitStart := time.Now() + for { + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) + c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false) + if c == nil || c2 == nil { + r.Log("nil") + } else { + version := c.Cert.Version() + theirVersion := c2.Cert.Version() + r.Logf("version %d,%d", version, theirVersion) + if version == theirVersion { + break + } + } + since := time.Since(waitStart) + if since > time.Second*5 { + r.Log("wtf") + } + if since > time.Second*10 { + r.Log("wtf") + t.Fatal("Cert should be new by now") + } + time.Sleep(time.Second) + } + + r.RenderHostmaps("Final hostmaps", myControl, theirControl) + + myControl.Stop() + theirControl.Stop() +} + +func TestCrossStackRelaysWork(t *testing.T) { + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24,fc00::1/64", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "10.128.0.128/24,fc00::128/64", m{"relay": m{"am_relay": true}}) + theirUdp := netip.MustParseAddrPort("10.0.0.2:4242") + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdp(cert.Version2, ca, caKey, "them ", "fc00::2/64", theirUdp, m{"relay": m{"use_relays": true}}) + + //myVpnV4 := myVpnIpNet[0] + myVpnV6 := myVpnIpNet[1] + relayVpnV4 := relayVpnIpNet[0] + relayVpnV6 := relayVpnIpNet[1] + theirVpnV6 := theirVpnIpNet[0] + + // Teach my how to get to the relay and that their can be reached via the relay + myControl.InjectLightHouseAddr(relayVpnV4.Addr(), relayUdpAddr) + myControl.InjectLightHouseAddr(relayVpnV6.Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnV6.Addr(), 
[]netip.Addr{relayVpnV6.Addr()}) + relayControl.InjectLightHouseAddr(theirVpnV6.Addr(), theirUdpAddr) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(t, myControl, relayControl, theirControl) + defer r.RenderFlow() + + // Start the servers + myControl.Start() + relayControl.Start() + theirControl.Start() + + t.Log("Trigger a handshake from me to them via the relay") + myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me")) + + p := r.RouteForAllUntilTxTun(theirControl) + r.Log("Assert the tunnel works") + assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80) + + t.Log("reply?") + theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them")) + p = r.RouteForAllUntilTxTun(myControl) + assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80) + + r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl) + //t.Log("finish up") + //myControl.Stop() + //theirControl.Stop() + //relayControl.Stop() +} diff --git a/examples/config.yml b/examples/config.yml index 1831be37..b98b32cc 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -424,8 +424,9 @@ firewall: # host: `any` or a literal hostname, ie `test-host` # group: `any` or a literal group name, ie `default-group` # groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass - # cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. - # local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This can be used to filter destinations when using unsafe_routes. + # cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. `any` means any ip family and address. + # local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. `any` means any ip family and address. 
+ # This can be used to filter destinations when using unsafe_routes. # By default, this is set to only the VPN (overlay) networks assigned via the certificate networks field unless `default_local_cidr_any` is set to true. # If there are unsafe_routes present in this config file, `local_cidr` should be set appropriately for the intended us case. # ca_name: An issuing CA name diff --git a/firewall.go b/firewall.go index 971c156d..45dc0691 100644 --- a/firewall.go +++ b/firewall.go @@ -8,6 +8,7 @@ import ( "hash/fnv" "net/netip" "reflect" + "slices" "strconv" "strings" "sync" @@ -22,7 +23,7 @@ import ( ) type FirewallInterface interface { - AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, addr, localAddr netip.Prefix, caName string, caSha string) error + AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr string, caName string, caSha string) error } type conn struct { @@ -247,22 +248,11 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew } // AddRule properly creates the in memory rule structure for a firewall table. -func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error { - // Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS - // https://github.com/golang/go/issues/14131 - sIp := "" - if ip.IsValid() { - sIp = ip.String() - } - lIp := "" - if localIp.IsValid() { - lIp = localIp.String() - } - +func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error { // We need this rule string because we generate a hash. Removing this will break firewall reload. 
ruleString := fmt.Sprintf( "incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, localIp: %v, caName: %v, caSha: %s", - incoming, proto, startPort, endPort, groups, host, sIp, lIp, caName, caSha, + incoming, proto, startPort, endPort, groups, host, cidr, localCidr, caName, caSha, ) f.rules += ruleString + "\n" @@ -270,7 +260,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort if !incoming { direction = "outgoing" } - f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "localIp": lIp, "caName": caName, "caSha": caSha}). + f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "cidr": cidr, "localCidr": localCidr, "caName": caName, "caSha": caSha}). Info("Firewall rule added") var ( @@ -297,7 +287,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort return fmt.Errorf("unknown protocol %v", proto) } - return fp.addRule(f, startPort, endPort, groups, host, ip, localIp, caName, caSha) + return fp.addRule(f, startPort, endPort, groups, host, cidr, localCidr, caName, caSha) } // GetRuleHash returns a hash representation of all inbound and outbound rules @@ -337,7 +327,6 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw } for i, t := range rs { - var groups []string r, err := convertRule(l, t, table, i) if err != nil { return fmt.Errorf("%s rule #%v; %s", table, i, err) @@ -347,23 +336,10 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw return fmt.Errorf("%s rule #%v; only one of port or code should be provided", table, i) } - if r.Host == "" && len(r.Groups) == 0 && r.Group == "" && r.Cidr == "" && r.LocalCidr == "" && r.CAName == "" && r.CASha == "" { + if r.Host == "" && len(r.Groups) == 0 && r.Cidr == "" && 
r.LocalCidr == "" && r.CAName == "" && r.CASha == "" { return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided", table, i) } - if len(r.Groups) > 0 { - groups = r.Groups - } - - if r.Group != "" { - // Check if we have both groups and group provided in the rule config - if len(groups) > 0 { - return fmt.Errorf("%s rule #%v; only one of group or groups should be defined, both provided", table, i) - } - - groups = []string{r.Group} - } - var sPort, errPort string if r.Code != "" { errPort = "code" @@ -392,23 +368,25 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto) } - var cidr netip.Prefix - if r.Cidr != "" { - cidr, err = netip.ParsePrefix(r.Cidr) + if r.Cidr != "" && r.Cidr != "any" { + _, err = netip.ParsePrefix(r.Cidr) if err != nil { return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err) } } - var localCidr netip.Prefix - if r.LocalCidr != "" { - localCidr, err = netip.ParsePrefix(r.LocalCidr) + if r.LocalCidr != "" && r.LocalCidr != "any" { + _, err = netip.ParsePrefix(r.LocalCidr) if err != nil { return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err) } } - err = fw.AddRule(inbound, proto, startPort, endPort, groups, r.Host, cidr, localCidr, r.CAName, r.CASha) + if warning := r.sanity(); warning != nil { + l.Warnf("%s rule #%v; %s", table, i, warning) + } + + err = fw.AddRule(inbound, proto, startPort, endPort, r.Groups, r.Host, r.Cidr, r.LocalCidr, r.CAName, r.CASha) if err != nil { return fmt.Errorf("%s rule #%v; `%s`", table, i, err) } @@ -417,8 +395,10 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw return nil } -var ErrInvalidRemoteIP = errors.New("remote IP is not in remote certificate subnets") -var ErrInvalidLocalIP = errors.New("local IP is not in list of handled local IPs") +var ErrUnknownNetworkType = 
errors.New("unknown network type") +var ErrPeerRejected = errors.New("remote address is not within a network that we handle") +var ErrInvalidRemoteIP = errors.New("remote address is not in remote certificate networks") +var ErrInvalidLocalIP = errors.New("local address is not in list of handled local addresses") var ErrNoMatchingRule = errors.New("no matching rule in firewall table") // Drop returns an error if the packet should be dropped, explaining why. It @@ -429,18 +409,31 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * return nil } - // Make sure remote address matches nebula certificate - if h.networks != nil { - if !h.networks.Contains(fp.RemoteAddr) { - f.metrics(incoming).droppedRemoteAddr.Inc(1) - return ErrInvalidRemoteIP - } - } else { + // Make sure remote address matches nebula certificate, and determine how to treat it + if h.networks == nil { // Simple case: Certificate has one address and no unsafe networks if h.vpnAddrs[0] != fp.RemoteAddr { f.metrics(incoming).droppedRemoteAddr.Inc(1) return ErrInvalidRemoteIP } + } else { + nwType, ok := h.networks.Lookup(fp.RemoteAddr) + if !ok { + f.metrics(incoming).droppedRemoteAddr.Inc(1) + return ErrInvalidRemoteIP + } + switch nwType { + case NetworkTypeVPN: + break // nothing special + case NetworkTypeVPNPeer: + f.metrics(incoming).droppedRemoteAddr.Inc(1) + return ErrPeerRejected // reject for now, one day this may have different FW rules + case NetworkTypeUnsafe: + break // nothing special, one day this may have different FW rules + default: + f.metrics(incoming).droppedRemoteAddr.Inc(1) + return ErrUnknownNetworkType //should never happen + } } // Make sure we are supposed to be handling this local ip address @@ -640,7 +633,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedC return false } -func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName 
string, caSha string) error { +func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, cidr, localCidr, caName string, caSha string) error { if startPort > endPort { return fmt.Errorf("start port was lower than end port") } @@ -653,7 +646,7 @@ func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, grou } } - if err := fp[i].addRule(f, groups, host, ip, localIp, caName, caSha); err != nil { + if err := fp[i].addRule(f, groups, host, cidr, localCidr, caName, caSha); err != nil { return err } } @@ -684,7 +677,7 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCer return fp[firewall.PortAny].match(p, c, caPool) } -func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp netip.Prefix, caName, caSha string) error { +func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, cidr, localCidr, caName, caSha string) error { fr := func() *FirewallRule { return &FirewallRule{ Hosts: make(map[string]*firewallLocalCIDR), @@ -698,14 +691,14 @@ func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, loc fc.Any = fr() } - return fc.Any.addRule(f, groups, host, ip, localIp) + return fc.Any.addRule(f, groups, host, cidr, localCidr) } if caSha != "" { if _, ok := fc.CAShas[caSha]; !ok { fc.CAShas[caSha] = fr() } - err := fc.CAShas[caSha].addRule(f, groups, host, ip, localIp) + err := fc.CAShas[caSha].addRule(f, groups, host, cidr, localCidr) if err != nil { return err } @@ -715,7 +708,7 @@ func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, loc if _, ok := fc.CANames[caName]; !ok { fc.CANames[caName] = fr() } - err := fc.CANames[caName].addRule(f, groups, host, ip, localIp) + err := fc.CANames[caName].addRule(f, groups, host, cidr, localCidr) if err != nil { return err } @@ -747,24 +740,24 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.CachedCertificate, caPool return 
fc.CANames[s.Certificate.Name()].match(p, c) } -func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error { +func (fr *FirewallRule) addRule(f *Firewall, groups []string, host, cidr, localCidr string) error { flc := func() *firewallLocalCIDR { return &firewallLocalCIDR{ LocalCIDR: new(bart.Lite), } } - if fr.isAny(groups, host, ip) { + if fr.isAny(groups, host, cidr) { if fr.Any == nil { fr.Any = flc() } - return fr.Any.addRule(f, localCIDR) + return fr.Any.addRule(f, localCidr) } if len(groups) > 0 { nlc := flc() - err := nlc.addRule(f, localCIDR) + err := nlc.addRule(f, localCidr) if err != nil { return err } @@ -780,30 +773,34 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, l if nlc == nil { nlc = flc() } - err := nlc.addRule(f, localCIDR) + err := nlc.addRule(f, localCidr) if err != nil { return err } fr.Hosts[host] = nlc } - if ip.IsValid() { - nlc, _ := fr.CIDR.Get(ip) - if nlc == nil { - nlc = flc() - } - err := nlc.addRule(f, localCIDR) + if cidr != "" { + c, err := netip.ParsePrefix(cidr) if err != nil { return err } - fr.CIDR.Insert(ip, nlc) + nlc, _ := fr.CIDR.Get(c) + if nlc == nil { + nlc = flc() + } + err = nlc.addRule(f, localCidr) + if err != nil { + return err + } + fr.CIDR.Insert(c, nlc) } return nil } -func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) bool { - if len(groups) == 0 && host == "" && !ip.IsValid() { +func (fr *FirewallRule) isAny(groups []string, host string, cidr string) bool { + if len(groups) == 0 && host == "" && cidr == "" { return true } @@ -817,7 +814,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) boo return true } - if ip.IsValid() && ip.Bits() == 0 { + if cidr == "any" { return true } @@ -869,8 +866,13 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool return false } -func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error { 
- if !localIp.IsValid() { +func (flc *firewallLocalCIDR) addRule(f *Firewall, localCidr string) error { + if localCidr == "any" { + flc.Any = true + return nil + } + + if localCidr == "" { if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny { flc.Any = true return nil @@ -881,12 +883,13 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error { } return nil - } else if localIp.Bits() == 0 { - flc.Any = true - return nil } - flc.LocalCIDR.Insert(localIp) + c, err := netip.ParsePrefix(localCidr) + if err != nil { + return err + } + flc.LocalCIDR.Insert(c) return nil } @@ -907,7 +910,6 @@ type rule struct { Code string Proto string Host string - Group string Groups []string Cidr string LocalCidr string @@ -949,7 +951,8 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) { l.Warnf("%s rule #%v; group was an array with a single value, converting to simple value", table, i) m["group"] = v[0] } - r.Group = toString("group", m) + + singleGroup := toString("group", m) if rg, ok := m["groups"]; ok { switch reflect.TypeOf(rg).Kind() { @@ -966,9 +969,60 @@ func convertRule(l *logrus.Logger, p any, table string, i int) (rule, error) { } } + //flatten group vs groups + if singleGroup != "" { + // Check if we have both groups and group provided in the rule config + if len(r.Groups) > 0 { + return r, fmt.Errorf("only one of group or groups should be defined, both provided") + } + r.Groups = []string{singleGroup} + } + return r, nil } +// sanity returns an error if the rule would be evaluated in a way that would short-circuit a configured check on a wildcard value +// rules are evaluated as "port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) AND local_cidr" +func (r *rule) sanity() error { + //port, proto, local_cidr are AND, no need to check here + //ca_sha and ca_name don't have a wildcard value, no need to check here + groupsEmpty := len(r.Groups) == 0 + hostEmpty := r.Host == "" + cidrEmpty := r.Cidr == 
"" + + if (groupsEmpty && hostEmpty && cidrEmpty) == true { + return nil //no content! + } + + groupsHasAny := slices.Contains(r.Groups, "any") + if groupsHasAny && len(r.Groups) > 1 { + return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the other groups specified", r.Groups) + } + + if r.Host == "any" { + if !groupsEmpty { + return fmt.Errorf("groups specified as %s, but host=any will match any host, regardless of groups", r.Groups) + } + + if !cidrEmpty { + return fmt.Errorf("cidr specified as %s, but host=any will match any host, regardless of cidr", r.Cidr) + } + } + + if groupsHasAny { + if !hostEmpty && r.Host != "any" { + return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the specified host %s", r.Groups, r.Host) + } + if !cidrEmpty { + return fmt.Errorf("groups spec [%s] contains the group '\"any\". This rule will ignore the specified cidr %s", r.Groups, r.Cidr) + } + } + + //todo alert on cidr-any + + return nil +} + func parsePort(s string) (startPort, endPort int32, err error) { if s == "any" { startPort = firewall.PortAny diff --git a/firewall_test.go b/firewall_test.go index 4731a6ff..1df62a81 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -8,6 +8,8 @@ import ( "testing" "time" + "github.com/gaissmai/bart" + "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" @@ -68,66 +70,117 @@ func TestFirewall_AddRule(t *testing.T) { ti, err := netip.ParsePrefix("1.2.3.4/32") require.NoError(t, err) - require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + ti6, err := netip.ParsePrefix("fd12::34/128") + require.NoError(t, err) + + require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", "", "", "", "")) // An empty rule is any assert.True(t, fw.InRules.TCP[1].Any.Any.Any) assert.Empty(t, fw.InRules.TCP[1].Any.Groups) 
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "")) assert.Nil(t, fw.InRules.UDP[1].Any.Any) assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1") assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", "")) assert.Nil(t, fw.InRules.ICMP[1].Any.Any) assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", "")) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti) assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", "")) - assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) - _, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti) + require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6.String(), "", "", "")) + assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) + _, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6) assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", "")) + 
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti.String(), "", "")) + assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) + ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti) + assert.True(t, ok) + + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti6.String(), "", "")) + assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) + ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6) + assert.True(t, ok) + + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "ca-name", "")) assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha")) + require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "ca-sha")) assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", "", "", "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) anyIp, err := netip.ParsePrefix("0.0.0.0/0") require.NoError(t, err) - require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp.String(), "", "", "")) + assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any) + table, ok := fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1")) + assert.True(t, table.Any) + table, ok = 
fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9")) + assert.False(t, ok) + + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + anyIp6, err := netip.ParsePrefix("::/0") + require.NoError(t, err) + + require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6.String(), "", "", "")) + assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any) + table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9")) + assert.True(t, table.Any) + table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1")) + assert.False(t, ok) + + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "any", "", "", "")) + assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) + + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp.String(), "", "")) + assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any) + assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1"))) + assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9"))) + + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp6.String(), "", "")) + assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any) + assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9"))) + assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1"))) + + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", "any", "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) // Test error conditions fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - require.Error(t, 
fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) - require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", "", "", "", "")) + require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", "", "", "", "")) } func TestFirewall_Drop(t *testing.T) { l := test.NewLogger() ob := &bytes.Buffer{} l.SetOutput(ob) - + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) p := firewall.Packet{ LocalAddr: netip.MustParseAddr("1.2.3.4"), RemoteAddr: netip.MustParseAddr("1.2.3.4"), @@ -152,10 +205,10 @@ func TestFirewall_Drop(t *testing.T) { }, vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")}, } - h.buildNetworks(c.networks, c.unsafeNetworks) + h.buildNetworks(myVpnNetworksTable, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) cp := cert.NewCAPool() // Drop outbound @@ -174,28 +227,107 @@ func TestFirewall_Drop(t *testing.T) { // ensure signer doesn't get in the way of group checks fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) assert.Equal(t, fw.Drop(p, true, 
&h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.Drop(p, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) + require.NoError(t, 
fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) + require.NoError(t, fw.Drop(p, true, &h, cp, nil)) +} + +func TestFirewall_DropV6(t *testing.T) { + l := test.NewLogger() + ob := &bytes.Buffer{} + l.SetOutput(ob) + + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7")) + + p := firewall.Packet{ + LocalAddr: netip.MustParseAddr("fd12::34"), + RemoteAddr: netip.MustParseAddr("fd12::34"), + LocalPort: 10, + RemotePort: 90, + Protocol: firewall.ProtoUDP, + Fragment: false, + } + + c := dummyCert{ + name: "host1", + networks: []netip.Prefix{netip.MustParsePrefix("fd12::34/120")}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + h := HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &cert.CachedCertificate{ + Certificate: &c, + InvertedGroups: map[string]struct{}{"default-group": {}}, + }, + }, + vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")}, + } + h.buildNetworks(myVpnNetworksTable, &c) + + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) + cp := cert.NewCAPool() + + // Drop outbound + assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil)) + // Allow inbound + resetConntrack(fw) + require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + // Allow outbound because conntrack + require.NoError(t, fw.Drop(p, false, &h, cp, nil)) + + // test remote mismatch + oldRemote := p.RemoteAddr + p.RemoteAddr = netip.MustParseAddr("fd12::56") + assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP) + p.RemoteAddr = oldRemote + + // ensure signer doesn't get in the way of group checks + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, 
[]string{"nope"}, "", "", "", "", "signer-shasum")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) + assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) + + // test caSha doesn't drop on match + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) + require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + + // ensure ca name doesn't get in the way of group checks + cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) + assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) + + // test caName doesn't drop on match + cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) require.NoError(t, fw.Drop(p, true, &h, cp, nil)) } @@ -206,8 +338,12 @@ func BenchmarkFirewallTable_match(b *testing.B) { } pfix := netip.MustParsePrefix("172.1.1.1/32") - _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "") - _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "") + _ = ft.TCP.addRule(f, 10, 10, 
[]string{"good-group"}, "good-host", pfix.String(), "", "", "") + _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", "", pfix.String(), "", "") + + pfix6 := netip.MustParsePrefix("fd11::11/128") + _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6.String(), "", "", "") + _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", "", pfix6.String(), "", "") cp := cert.NewCAPool() b.Run("fail on proto", func(b *testing.B) { @@ -239,6 +375,15 @@ func BenchmarkFirewallTable_match(b *testing.B) { assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp)) } }) + b.Run("pass proto, port, fail on local CIDRv6", func(b *testing.B) { + c := &cert.CachedCertificate{ + Certificate: &dummyCert{}, + } + ip := netip.MustParsePrefix("fd99::99/128") + for n := 0; n < b.N; n++ { + assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp)) + } + }) b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) { c := &cert.CachedCertificate{ @@ -252,6 +397,18 @@ func BenchmarkFirewallTable_match(b *testing.B) { assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)) } }) + b.Run("pass proto, port, any local CIDRv6, fail all group, name, and cidr", func(b *testing.B) { + c := &cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "nope", + networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")}, + }, + InvertedGroups: map[string]struct{}{"nope": {}}, + } + for n := 0; n < b.N; n++ { + assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)) + } + }) b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) { c := &cert.CachedCertificate{ @@ -265,6 +422,18 @@ func BenchmarkFirewallTable_match(b *testing.B) { assert.False(b, 
ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp)) } }) + b.Run("pass proto, port, specific local CIDRv6, fail all group, name, and cidr", func(b *testing.B) { + c := &cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "nope", + networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")}, + }, + InvertedGroups: map[string]struct{}{"nope": {}}, + } + for n := 0; n < b.N; n++ { + assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp)) + } + }) b.Run("pass on group on any local cidr", func(b *testing.B) { c := &cert.CachedCertificate{ @@ -289,6 +458,17 @@ func BenchmarkFirewallTable_match(b *testing.B) { assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp)) } }) + b.Run("pass on group on specific local cidr6", func(b *testing.B) { + c := &cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "nope", + }, + InvertedGroups: map[string]struct{}{"good-group": {}}, + } + for n := 0; n < b.N; n++ { + assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp)) + } + }) b.Run("pass on name", func(b *testing.B) { c := &cert.CachedCertificate{ @@ -307,6 +487,8 @@ func TestFirewall_Drop2(t *testing.T) { l := test.NewLogger() ob := &bytes.Buffer{} l.SetOutput(ob) + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) p := firewall.Packet{ LocalAddr: netip.MustParseAddr("1.2.3.4"), @@ -332,7 +514,7 @@ func TestFirewall_Drop2(t *testing.T) { }, vpnAddrs: []netip.Addr{network.Addr()}, } - h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks()) + h.buildNetworks(myVpnNetworksTable, c.Certificate) c1 := cert.CachedCertificate{ Certificate: &dummyCert{ @@ -347,10 +529,10 @@ func TestFirewall_Drop2(t *testing.T) { peerCert: &c1, }, } - 
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks()) + h1.buildNetworks(myVpnNetworksTable, c1.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", "", "", "", "")) cp := cert.NewCAPool() // h1/c1 lacks the proper groups @@ -364,6 +546,8 @@ func TestFirewall_Drop3(t *testing.T) { l := test.NewLogger() ob := &bytes.Buffer{} l.SetOutput(ob) + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) p := firewall.Packet{ LocalAddr: netip.MustParseAddr("1.2.3.4"), @@ -395,7 +579,7 @@ func TestFirewall_Drop3(t *testing.T) { }, vpnAddrs: []netip.Addr{network.Addr()}, } - h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks()) + h1.buildNetworks(myVpnNetworksTable, c1.Certificate) c2 := cert.CachedCertificate{ Certificate: &dummyCert{ @@ -410,7 +594,7 @@ func TestFirewall_Drop3(t *testing.T) { }, vpnAddrs: []netip.Addr{network.Addr()}, } - h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks()) + h2.buildNetworks(myVpnNetworksTable, c2.Certificate) c3 := cert.CachedCertificate{ Certificate: &dummyCert{ @@ -425,11 +609,11 @@ func TestFirewall_Drop3(t *testing.T) { }, vpnAddrs: []netip.Addr{network.Addr()}, } - h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks()) + h3.buildNetworks(myVpnNetworksTable, c3.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", "")) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha")) + 
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", "", "", "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "signer-sha")) cp := cert.NewCAPool() // c1 should pass because host match @@ -443,14 +627,54 @@ func TestFirewall_Drop3(t *testing.T) { // Test a remote address match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", "")) require.NoError(t, fw.Drop(p, true, &h1, cp, nil)) } +func TestFirewall_Drop3V6(t *testing.T) { + l := test.NewLogger() + ob := &bytes.Buffer{} + l.SetOutput(ob) + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7")) + + p := firewall.Packet{ + LocalAddr: netip.MustParseAddr("fd12::34"), + RemoteAddr: netip.MustParseAddr("fd12::34"), + LocalPort: 1, + RemotePort: 1, + Protocol: firewall.ProtoUDP, + Fragment: false, + } + + network := netip.MustParsePrefix("fd12::34/120") + c := cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "host-owner", + networks: []netip.Prefix{network}, + }, + } + h := HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &c, + }, + vpnAddrs: []netip.Addr{network.Addr()}, + } + h.buildNetworks(myVpnNetworksTable, c.Certificate) + + // Test a remote address match + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + cp := cert.NewCAPool() + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", "")) + require.NoError(t, fw.Drop(p, true, &h, cp, nil)) +} + func TestFirewall_DropConntrackReload(t *testing.T) { l := test.NewLogger() ob := &bytes.Buffer{} l.SetOutput(ob) + myVpnNetworksTable := new(bart.Lite) + 
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) p := firewall.Packet{ LocalAddr: netip.MustParseAddr("1.2.3.4"), @@ -477,10 +701,10 @@ func TestFirewall_DropConntrackReload(t *testing.T) { }, vpnAddrs: []netip.Addr{network.Addr()}, } - h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks()) + h.buildNetworks(myVpnNetworksTable, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) cp := cert.NewCAPool() // Drop outbound @@ -493,7 +717,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { oldFw := fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", "", "", "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 @@ -502,7 +726,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { oldFw = fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", "", "", "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 @@ -510,6 +734,52 @@ func TestFirewall_DropConntrackReload(t *testing.T) { assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) } +func TestFirewall_DropIPSpoofing(t *testing.T) { + l := test.NewLogger() + ob := &bytes.Buffer{} + l.SetOutput(ob) + myVpnNetworksTable := new(bart.Lite) + 
myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24")) + + c := cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "host-owner", + networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.1/24")}, + }, + } + + c1 := cert.CachedCertificate{ + Certificate: &dummyCert{ + name: "host", + networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.2/24")}, + unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")}, + }, + } + h1 := HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &c1, + }, + vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()}, + } + h1.buildNetworks(myVpnNetworksTable, c1.Certificate) + + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "")) + cp := cert.NewCAPool() + + // Packet spoofed by `c1`. Note that the remote addr is not a valid one. + p := firewall.Packet{ + LocalAddr: netip.MustParseAddr("192.0.2.1"), + RemoteAddr: netip.MustParseAddr("192.0.2.3"), + LocalPort: 1, + RemotePort: 1, + Protocol: firewall.ProtoUDP, + Fragment: false, + } + assert.Equal(t, fw.Drop(p, true, &h1, cp, nil), ErrInvalidRemoteIP) +} + func BenchmarkLookup(b *testing.B) { ml := func(m map[string]struct{}, a [][]string) { for n := 0; n < b.N; n++ { @@ -689,28 +959,28 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { mf := &mockFirewall{} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) - assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) // Test adding udp rule conf = config.NewC(l) mf = 
&mockFirewall{} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) - assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) // Test adding icmp rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) - assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) // Test adding any rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall) // Test adding rule with cidr cidr := netip.MustParsePrefix("10.0.0.0/8") @@ -718,49 +988,90 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": 
"any", "cidr": cidr.String()}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr.String(), localIp: ""}, mf.lastCall) // Test adding rule with local_cidr conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr.String()}, mf.lastCall) + + // Test adding rule with cidr ipv6 + cidr6 := netip.MustParsePrefix("fd00::/8") + conf = config.NewC(l) + mf = &mockFirewall{} + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}} + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6.String(), localIp: ""}, mf.lastCall) + + // Test adding rule with any cidr + conf = config.NewC(l) + mf = &mockFirewall{} + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "any"}}} + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "any", localIp: ""}, mf.lastCall) + + // Test adding rule with junk cidr + conf = config.NewC(l) + mf = 
&mockFirewall{} + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "junk/junk"}}} + require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP") + + // Test adding rule with local_cidr ipv6 + conf = config.NewC(l) + mf = &mockFirewall{} + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}} + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr6.String()}, mf.lastCall) + + // Test adding rule with any local_cidr + conf = config.NewC(l) + mf = &mockFirewall{} + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "any"}}} + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, localIp: "any"}, mf.lastCall) + + // Test adding rule with junk local_cidr + conf = config.NewC(l) + mf = &mockFirewall{} + conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "junk/junk"}}} + require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP") // Test adding rule with ca_sha conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, 
startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caSha: "12312313123"}, mf.lastCall) // Test adding rule with ca_name conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caName: "root01"}, mf.lastCall) // Test single group conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall) // Test single groups conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, 
startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall) // Test multiple AND groups conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) - assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) + assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: "", localIp: ""}, mf.lastCall) // Test Add error conf = config.NewC(l) @@ -783,7 +1094,7 @@ func TestFirewall_convertRule(t *testing.T) { r, err := convertRule(l, c, "test", 1) assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value") require.NoError(t, err) - assert.Equal(t, "group1", r.Group) + assert.Equal(t, []string{"group1"}, r.Groups) // Ensure group array of > 1 is errord ob.Reset() @@ -803,7 +1114,228 @@ func TestFirewall_convertRule(t *testing.T) { r, err = convertRule(l, c, "test", 1) require.NoError(t, err) - assert.Equal(t, "group1", r.Group) + assert.Equal(t, []string{"group1"}, r.Groups) +} + +func TestFirewall_convertRuleSanity(t *testing.T) { + l := test.NewLogger() + ob := &bytes.Buffer{} + l.SetOutput(ob) + + noWarningPlease := []map[string]any{ + {"group": "group1"}, + {"groups": []any{"group2"}}, + {"host": "bob"}, + {"cidr": "1.1.1.1/1"}, + {"groups": []any{"group2"}, "host": "bob"}, + {"cidr": "1.1.1.1/1", "host": "bob"}, + {"groups": []any{"group2"}, "cidr": "1.1.1.1/1"}, + {"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"}, + } + for _, c := range noWarningPlease { + r, err := convertRule(l, c, "test", 1) + require.NoError(t, err) + require.NoError(t, r.sanity(), "should not generate a sanity warning, %+v", c) + 
} + + yesWarningPlease := []map[string]any{ + {"group": "group1"}, + {"groups": []any{"group2"}}, + {"cidr": "1.1.1.1/1"}, + {"groups": []any{"group2"}, "host": "bob"}, + {"cidr": "1.1.1.1/1", "host": "bob"}, + {"groups": []any{"group2"}, "cidr": "1.1.1.1/1"}, + {"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"}, + } + for _, c := range yesWarningPlease { + c["host"] = "any" + r, err := convertRule(l, c, "test", 1) + require.NoError(t, err) + err = r.sanity() + require.Error(t, err, "I wanted a warning: %+v", c) + } + //reset the list + yesWarningPlease = []map[string]any{ + {"group": "group1"}, + {"groups": []any{"group2"}}, + {"cidr": "1.1.1.1/1"}, + {"groups": []any{"group2"}, "host": "bob"}, + {"cidr": "1.1.1.1/1", "host": "bob"}, + {"groups": []any{"group2"}, "cidr": "1.1.1.1/1"}, + {"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"}, + } + for _, c := range yesWarningPlease { + r, err := convertRule(l, c, "test", 1) + require.NoError(t, err) + r.Groups = append(r.Groups, "any") + err = r.sanity() + require.Error(t, err, "I wanted a warning: %+v", c) + } +} + +type testcase struct { + h *HostInfo + p firewall.Packet + c cert.Certificate + err error +} + +func (c *testcase) Test(t *testing.T, fw *Firewall) { + t.Helper() + cp := cert.NewCAPool() + resetConntrack(fw) + err := fw.Drop(c.p, true, c.h, cp, nil) + if c.err == nil { + require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr) + } else { + require.ErrorIs(t, c.err, err, "failed to drop remote address %s", c.p.RemoteAddr) + } +} + +func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) testcase { + c1 := dummyCert{ + name: "host1", + networks: theirPrefixes, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + h := HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &cert.CachedCertificate{ + Certificate: &c1, + InvertedGroups: map[string]struct{}{"default-group": {}}, + }, + }, + vpnAddrs: make([]netip.Addr, 
len(theirPrefixes)), + } + for i := range theirPrefixes { + h.vpnAddrs[i] = theirPrefixes[i].Addr() + } + h.buildNetworks(setup.myVpnNetworksTable, &c1) + p := firewall.Packet{ + LocalAddr: setup.c.Networks()[0].Addr(), //todo? + RemoteAddr: theirPrefixes[0].Addr(), + LocalPort: 10, + RemotePort: 90, + Protocol: firewall.ProtoUDP, + Fragment: false, + } + return testcase{ + h: &h, + p: p, + c: &c1, + err: err, + } +} + +type testsetup struct { + c dummyCert + myVpnNetworksTable *bart.Lite + fw *Firewall +} + +func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testsetup { + c := dummyCert{ + name: "me", + networks: myPrefixes, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + return newSetupFromCert(t, l, c) +} + +func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup { + myVpnNetworksTable := new(bart.Lite) + for _, prefix := range c.Networks() { + myVpnNetworksTable.Insert(prefix) + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) + + return testsetup{ + c: c, + fw: fw, + myVpnNetworksTable: myVpnNetworksTable, + } +} + +func TestFirewall_Drop_EnforceIPMatch(t *testing.T) { + t.Parallel() + l := test.NewLogger() + ob := &bytes.Buffer{} + l.SetOutput(ob) + + myPrefix := netip.MustParsePrefix("1.1.1.1/8") + // for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out + t.Run("allow inbound all matching", func(t *testing.T) { + t.Parallel() + setup := newSetup(t, l, myPrefix) + tc := buildTestCase(setup, nil, netip.MustParsePrefix("1.2.3.4/24")) + tc.Test(t, setup.fw) + }) + t.Run("allow inbound local matching", func(t *testing.T) { + t.Parallel() + setup := newSetup(t, l, myPrefix) + tc := buildTestCase(setup, ErrInvalidLocalIP, netip.MustParsePrefix("1.2.3.4/24")) + tc.p.LocalAddr = netip.MustParseAddr("1.2.3.8") + tc.Test(t, 
setup.fw) + }) + t.Run("block inbound remote mismatched", func(t *testing.T) { + t.Parallel() + setup := newSetup(t, l, myPrefix) + tc := buildTestCase(setup, ErrInvalidRemoteIP, netip.MustParsePrefix("1.2.3.4/24")) + tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9") + tc.Test(t, setup.fw) + }) + t.Run("Block a vpn peer packet", func(t *testing.T) { + t.Parallel() + setup := newSetup(t, l, myPrefix) + tc := buildTestCase(setup, ErrPeerRejected, netip.MustParsePrefix("2.2.2.2/24")) + tc.Test(t, setup.fw) + }) + twoPrefixes := []netip.Prefix{ + netip.MustParsePrefix("1.2.3.4/24"), netip.MustParsePrefix("2.2.2.2/24"), + } + t.Run("allow inbound one matching", func(t *testing.T) { + t.Parallel() + setup := newSetup(t, l, myPrefix) + tc := buildTestCase(setup, nil, twoPrefixes...) + tc.Test(t, setup.fw) + }) + t.Run("block inbound multimismatch", func(t *testing.T) { + t.Parallel() + setup := newSetup(t, l, myPrefix) + tc := buildTestCase(setup, ErrInvalidRemoteIP, twoPrefixes...) + tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9") + tc.Test(t, setup.fw) + }) + t.Run("allow inbound 2nd one matching", func(t *testing.T) { + t.Parallel() + setup2 := newSetup(t, l, netip.MustParsePrefix("2.2.2.1/24")) + tc := buildTestCase(setup2, nil, twoPrefixes...) + tc.p.RemoteAddr = twoPrefixes[1].Addr() + tc.Test(t, setup2.fw) + }) + t.Run("allow inbound unsafe route", func(t *testing.T) { + t.Parallel() + unsafePrefix := netip.MustParsePrefix("192.168.0.0/24") + c := dummyCert{ + name: "me", + networks: []netip.Prefix{myPrefix}, + unsafeNetworks: []netip.Prefix{unsafePrefix}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + unsafeSetup := newSetupFromCert(t, l, c) + tc := buildTestCase(unsafeSetup, nil, twoPrefixes...) 
+ tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3") + tc.err = ErrNoMatchingRule + tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off + require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", unsafePrefix.String(), "", "")) + tc.err = nil + tc.Test(t, unsafeSetup.fw) //should pass + }) } type addRuleCall struct { @@ -813,8 +1345,8 @@ type addRuleCall struct { endPort int32 groups []string host string - ip netip.Prefix - localIp netip.Prefix + ip string + localIp string caName string caSha string } @@ -824,7 +1356,7 @@ type mockFirewall struct { nextCallReturn error } -func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip netip.Prefix, localIp netip.Prefix, caName string, caSha string) error { +func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp, caName string, caSha string) error { mf.lastCall = addRuleCall{ incoming: incoming, proto: proto, diff --git a/go.mod b/go.mod index d552a7c1..e927eb8b 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,6 @@ module github.com/slackhq/nebula -go 1.23.0 - -toolchain go1.24.1 +go 1.25 require ( dario.cat/mergo v1.0.2 @@ -10,30 +8,31 @@ require ( github.com/armon/go-radix v1.0.0 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 github.com/flynn/noise v1.1.0 - github.com/gaissmai/bart v0.20.4 + github.com/gaissmai/bart v0.26.0 github.com/gogo/protobuf v1.3.2 github.com/google/gopacket v1.1.19 - github.com/kardianos/service v1.2.2 - github.com/miekg/dns v1.1.65 + github.com/kardianos/service v1.2.4 + github.com/miekg/dns v1.1.68 github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b github.com/nbrownus/go-metrics-prometheus v0.0.0-20210712211119-974a6260965f - github.com/prometheus/client_golang v1.22.0 + github.com/prometheus/client_golang v1.23.2 github.com/rcrowley/go-metrics 
v0.0.0-20201227073835-cf1acfcdf475 github.com/sirupsen/logrus v1.9.3 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 - github.com/stretchr/testify v1.10.0 + github.com/stretchr/testify v1.11.1 github.com/vishvananda/netlink v1.3.1 - golang.org/x/crypto v0.37.0 + go.yaml.in/yaml/v3 v3.0.4 + golang.org/x/crypto v0.45.0 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 - golang.org/x/net v0.39.0 - golang.org/x/sync v0.13.0 - golang.org/x/sys v0.32.0 - golang.org/x/term v0.31.0 + golang.org/x/net v0.47.0 + golang.org/x/sync v0.18.0 + golang.org/x/sys v0.38.0 + golang.org/x/term v0.37.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard/windows v0.5.3 - google.golang.org/protobuf v1.36.6 + google.golang.org/protobuf v1.36.10 gopkg.in/yaml.v3 v3.0.1 gvisor.dev/gvisor v0.0.0-20240423190808-9d7a357edefe ) @@ -45,11 +44,12 @@ require ( github.com/google/btree v1.1.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/prometheus/client_model v0.6.1 // indirect - github.com/prometheus/common v0.62.0 // indirect - github.com/prometheus/procfs v0.15.1 // indirect + github.com/prometheus/client_model v0.6.2 // indirect + github.com/prometheus/common v0.66.1 // indirect + github.com/prometheus/procfs v0.16.1 // indirect github.com/vishvananda/netns v0.0.5 // indirect - golang.org/x/mod v0.23.0 // indirect + go.yaml.in/yaml/v2 v2.4.2 // indirect + golang.org/x/mod v0.24.0 // indirect golang.org/x/time v0.5.0 // indirect - golang.org/x/tools v0.30.0 // indirect + golang.org/x/tools v0.33.0 // indirect ) diff --git a/go.sum b/go.sum index a932e586..3679fac6 100644 --- a/go.sum +++ b/go.sum @@ -24,8 +24,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c 
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg= github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= -github.com/gaissmai/bart v0.20.4 h1:Ik47r1fy3jRVU+1eYzKSW3ho2UgBVTVnUS8O993584U= -github.com/gaissmai/bart v0.20.4/go.mod h1:cEed+ge8dalcbpi8wtS9x9m2hn/fNJH5suhdGQOHnYk= +github.com/gaissmai/bart v0.26.0 h1:xOZ57E9hJLBiQaSyeZa9wgWhGuzfGACgqp4BE77OkO0= +github.com/gaissmai/bart v0.26.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= @@ -64,8 +64,8 @@ github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/ github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= -github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX60= -github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= +github.com/kardianos/service v1.2.4 h1:XNlGtZOYNx2u91urOdg/Kfmc+gfmuIo1Dd3rEi2OgBk= +github.com/kardianos/service v1.2.4/go.mod h1:E4V9ufUuY82F7Ztlu1eN9VXWIQxg8NoLQlmFe0MtrXc= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= @@ -83,8 +83,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kylelemons/godebug v1.1.0 
h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/miekg/dns v1.1.65 h1:0+tIPHzUW0GCge7IiK3guGP57VAw7hoPDfApjkMD1Fc= -github.com/miekg/dns v1.1.65/go.mod h1:Dzw9769uoKVaLuODMDZz9M6ynFU6Em65csPuoi8G0ck= +github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA= +github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps= github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b h1:J/AzCvg5z0Hn1rqZUJjpbzALUmkKX0Zwbc/i4fw7Sfk= github.com/miekg/pkcs11 v1.1.2-0.20231115102856-9078ad6b9d4b/go.mod h1:XsNlhZGX73bx86s2hdc/FuaLm2CPZJemRLMA+WTFxgs= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -106,24 +106,24 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= -github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q= -github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0= +github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= +github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model 
v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= -github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= +github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= +github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc= -github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io= -github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I= +github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs= +github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= -github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= -github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= +github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg= +github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM= github.com/rcrowley/go-metrics 
v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= @@ -143,29 +143,35 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0= github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4= github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY= github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= +go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= +go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod 
h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= -golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= -golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= +golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= +golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM= -golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= +golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU= +golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod 
h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -176,8 +182,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= -golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= +golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= +golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -185,8 +191,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= -golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= +golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys 
v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -197,18 +203,17 @@ golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= -golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o= -golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw= +golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU= +golang.org/x/term v0.37.0/go.mod 
h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -219,8 +224,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY= -golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY= +golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc= +golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -239,8 +244,8 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= -google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= +google.golang.org/protobuf v1.36.10 
h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= +google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/handshake_ix.go b/handshake_ix.go index 54b68bb3..25a0c371 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -2,7 +2,6 @@ package nebula import ( "net/netip" - "slices" "time" "github.com/flynn/noise" @@ -24,13 +23,17 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { return false } - // If we're connecting to a v6 address we must use a v2 cert cs := f.pki.getCertState() v := cs.initiatingVersion - for _, a := range hh.hostinfo.vpnAddrs { - if a.Is6() { - v = cert.Version2 - break + if hh.initiatingVersionOverride != cert.VersionPre1 { + v = hh.initiatingVersionOverride + } else if v < cert.Version2 { + // If we're connecting to a v6 address we should encourage use of a V2 cert + for _, a := range hh.hostinfo.vpnAddrs { + if a.Is6() { + v = cert.Version2 + break + } } } @@ -49,6 +52,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). WithField("certVersion", v). 
Error("Unable to handshake with host because no certificate handshake bytes is available") + return false } ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX) @@ -105,19 +109,20 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { return true } -func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) { +func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) { cs := f.pki.getCertState() crt := cs.GetDefaultCertificate() if crt == nil { - f.l.WithField("udpAddr", addr). + f.l.WithField("from", via). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). WithField("certVersion", cs.initiatingVersion). Error("Unable to handshake with host because no certificate is available") + return } ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX) if err != nil { - f.l.WithError(err).WithField("udpAddr", addr). + f.l.WithError(err).WithField("from", via). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). Error("Failed to create connection state") return @@ -128,7 +133,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:]) if err != nil { - f.l.WithError(err).WithField("udpAddr", addr). + f.l.WithError(err).WithField("from", via). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). Error("Failed to call noise.ReadMessage") return @@ -137,7 +142,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet hs := &NebulaHandshake{} err = hs.Unmarshal(msg) if err != nil || hs.Details == nil { - f.l.WithError(err).WithField("udpAddr", addr). + f.l.WithError(err).WithField("from", via). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). 
Error("Failed unmarshal handshake message") return @@ -145,7 +150,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve()) if err != nil { - f.l.WithError(err).WithField("udpAddr", addr). + f.l.WithError(err).WithField("from", via). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). Info("Handshake did not contain a certificate") return @@ -153,12 +158,12 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc) if err != nil { - fp, err := rc.Fingerprint() - if err != nil { + fp, fperr := rc.Fingerprint() + if fperr != nil { fp = "" } - e := f.l.WithError(err).WithField("udpAddr", addr). + e := f.l.WithError(err).WithField("from", via). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("certVpnNetworks", rc.Networks()). WithField("certFingerprint", fp) @@ -173,37 +178,40 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet if remoteCert.Certificate.Version() != ci.myCert.Version() { // We started off using the wrong certificate version, lets see if we can match the version that was sent to us - rc := cs.getCertificate(remoteCert.Certificate.Version()) - if rc == nil { - f.l.WithError(err).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert). 
- Info("Unable to handshake with host due to missing certificate version") - return + myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version()) + if myCertOtherVersion == nil { + if f.l.Level >= logrus.DebugLevel { + f.l.WithError(err).WithFields(m{ + "from": via, + "handshake": m{"stage": 1, "style": "ix_psk0"}, + "cert": remoteCert, + }).Debug("Might be unable to handshake with host due to missing certificate version") + } + } else { + // Record the certificate we are actually using + ci.myCert = myCertOtherVersion } - - // Record the certificate we are actually using - ci.myCert = rc } if len(remoteCert.Certificate.Networks()) == 0 { - f.l.WithError(err).WithField("udpAddr", addr). + f.l.WithError(err).WithField("from", via). WithField("cert", remoteCert). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). Info("No networks in certificate") return } - var vpnAddrs []netip.Addr - var filteredNetworks []netip.Prefix certName := remoteCert.Certificate.Name() certVersion := remoteCert.Certificate.Version() fingerprint := remoteCert.Fingerprint issuer := remoteCert.Certificate.Issuer() + vpnNetworks := remoteCert.Certificate.Networks() - for _, network := range remoteCert.Certificate.Networks() { - vpnAddr := network.Addr() - if f.myVpnAddrsTable.Contains(vpnAddr) { - f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr). + anyVpnAddrsInCommon := false + vpnAddrs := make([]netip.Addr, len(vpnNetworks)) + for i, network := range vpnNetworks { + if f.myVpnAddrsTable.Contains(network.Addr()) { + f.l.WithField("vpnNetworks", vpnNetworks).WithField("from", via). WithField("certName", certName). WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). 
@@ -211,38 +219,24 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself") return } - - // vpnAddrs outside our vpn networks are of no use to us, filter them out - if !f.myVpnNetworksTable.Contains(vpnAddr) { - continue + vpnAddrs[i] = network.Addr() + if f.myVpnNetworksTable.Contains(network.Addr()) { + anyVpnAddrsInCommon = true } - - filteredNetworks = append(filteredNetworks, network) - vpnAddrs = append(vpnAddrs, vpnAddr) } - if len(vpnAddrs) == 0 { - f.l.WithError(err).WithField("udpAddr", addr). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake") - return - } - - if addr.IsValid() { - // addr can be invalid when the tunnel is being relayed. + if !via.IsRelayed { // We only want to apply the remote allow list for direct tunnels here - if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, addr.Addr()) { - f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") + if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, via.UdpAddr.Addr()) { + f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). + Debug("lighthouse.remote_allow_list denied incoming handshake") return } } myIndex, err := generateIndex(f.l) if err != nil { - f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via). WithField("certName", certName). WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). 
@@ -265,10 +259,10 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet TotalPorts: uint32(f.multiPort.TxPorts), } } - if hs.Details.InitiatorMultiPort != nil && hs.Details.InitiatorMultiPort.BasePort != uint32(addr.Port()) { + if hs.Details.InitiatorMultiPort != nil && hs.Details.InitiatorMultiPort.BasePort != uint32(via.UdpAddr.Port()) { // The other side sent us a handshake from a different port, make sure // we send responses back to the BasePort - addr = netip.AddrPortFrom(addr.Addr(), uint16(hs.Details.InitiatorMultiPort.BasePort)) + via.UdpAddr = netip.AddrPortFrom(via.UdpAddr.Addr(), uint16(hs.Details.InitiatorMultiPort.BasePort)) } hostinfo := &HostInfo{ @@ -287,27 +281,32 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet }, } - f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - WithField("multiportTx", multiportTx).WithField("multiportRx", multiportRx). - Info("Handshake message received") + msgRxL := f.l.WithFields(m{ + "vpnAddrs": vpnAddrs, + "from": via, + "certName": certName, + "certVersion": certVersion, + "fingerprint": fingerprint, + "issuer": issuer, + "initiatorIndex": hs.Details.InitiatorIndex, + "responderIndex": hs.Details.ResponderIndex, + "remoteIndex": h.RemoteIndex, + "multiportTx": multiportTx, + "multiportRx": multiportRx, + "handshake": m{"stage": 1, "style": "ix_psk0"}, + }) + + if anyVpnAddrsInCommon { + msgRxL.Info("Handshake message received") + } else { + //todo warn if not lighthouse or relay? 
+ msgRxL.Info("Handshake message received, but no vpnNetworks in common.") + } hs.Details.ResponderIndex = myIndex hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version()) if hs.Details.Cert == nil { - f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - WithField("certVersion", ci.myCert.Version()). + msgRxL.WithField("myCertVersion", ci.myCert.Version()). Error("Unable to handshake with host because no certificate handshake bytes is available") return } @@ -318,7 +317,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet hsBytes, err := hs.Marshal() if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). WithField("certName", certName). WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). @@ -330,7 +329,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2) msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes) if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). WithField("certName", certName). WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). 
@@ -338,7 +337,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") return } else if dKey == nil || eKey == nil { - f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). WithField("certName", certName). WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). @@ -364,8 +363,10 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet ci.eKey = NewNebulaCipherState(eKey) hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs) - hostinfo.SetRemote(addr) - hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks()) + if !via.IsRelayed { + hostinfo.SetRemote(via.UdpAddr) + } + hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate) existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f) if err != nil { @@ -373,10 +374,10 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet case ErrAlreadySeen: if hostinfo.multiportRx { // The other host is sending to us with multiport, so only grab the IP - addr = netip.AddrPortFrom(addr.Addr(), hostinfo.remote.Port()) + via.UdpAddr = netip.AddrPortFrom(via.UdpAddr.Addr(), hostinfo.remote.Port()) } // Update remote if preferred - if existing.SetRemoteIfPreferred(f.hostMap, addr) { + if existing.SetRemoteIfPreferred(f.hostMap, via) { // Send a test packet to ensure the other side has also switched to // the preferred remote f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) @@ -384,28 +385,29 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet msg = existing.HandshakePacket[2] f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) - if addr.IsValid() { + if !via.IsRelayed { + 
err := f.outside.WriteTo(msg, via.UdpAddr) if multiportTx { // TODO remove alloc here raw := make([]byte, len(msg)+udp.RawOverhead) copy(raw[udp.RawOverhead:], msg) - err = f.udpRaw.WriteTo(raw, udp.RandomSendPort.UDPSendPort(f.multiPort.TxPorts), addr) + err = f.udpRaw.WriteTo(raw, udp.RandomSendPort.UDPSendPort(f.multiPort.TxPorts), via.UdpAddr) } else { - err = f.outside.WriteTo(msg, addr) + err = f.outside.WriteTo(msg, via.UdpAddr) } if err != nil { - f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). WithError(err).Error("Failed to send handshake message") } else { - f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). Info("Handshake message sent") } return } else { - if via == nil { - f.l.Error("Handshake send failed: both addr and via are nil.") + if via.relay == nil { + f.l.Error("Handshake send failed: both addr and via.relay are nil.") return } hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) @@ -417,7 +419,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet } case ErrExistingHostInfo: // This means there was an existing tunnel and this handshake was older than the one we are currently based on - f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). WithField("certName", certName). WithField("certVersion", certVersion). WithField("oldHandshakeTime", existing.lastHandshakeTime). @@ -433,7 +435,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet return case ErrLocalIndexCollision: // This means we failed to insert because of collision on localIndexId. 
Just let the next handshake packet retry - f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). WithField("certName", certName). WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). @@ -446,7 +448,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet default: // Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete // And we forget to update it here - f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via). WithField("certName", certName). WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). @@ -460,37 +462,30 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet // Do the send f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) - if addr.IsValid() { + if !via.IsRelayed { if multiportTx { // TODO remove alloc here raw := make([]byte, len(msg)+udp.RawOverhead) copy(raw[udp.RawOverhead:], msg) - err = f.udpRaw.WriteTo(raw, udp.RandomSendPort.UDPSendPort(f.multiPort.TxPorts), addr) + err = f.udpRaw.WriteTo(raw, udp.RandomSendPort.UDPSendPort(f.multiPort.TxPorts), via.UdpAddr) } else { - err = f.outside.WriteTo(msg, addr) + err = f.outside.WriteTo(msg, via.UdpAddr) } + log := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). + WithField("certName", certName). + WithField("certVersion", certVersion). + WithField("fingerprint", fingerprint). + WithField("issuer", issuer). + WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). + WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}) if err != nil { - f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). - WithField("certName", certName). 
- WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - WithError(err).Error("Failed to send handshake") + log.WithError(err).Error("Failed to send handshake") } else { - f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - Info("Handshake message sent") + log.Info("Handshake message sent") } } else { - if via == nil { - f.l.Error("Handshake send failed: both addr and via are nil.") + if via.relay == nil { + f.l.Error("Handshake send failed: both addr and via.relay are nil.") return } hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) @@ -510,12 +505,12 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet f.connectionManager.AddTrafficWatch(hostinfo) - hostinfo.remotes.ResetBlockedRemotes() + hostinfo.remotes.RefreshFromHandshake(vpnAddrs) return } -func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool { +func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool { if hh == nil { // Nothing here to tear down, got a bogus stage 2 packet return true @@ -525,10 +520,10 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha defer hh.Unlock() hostinfo := hh.hostinfo - if addr.IsValid() { + if !via.IsRelayed { // The vpnAddr we know about is the 
one we tried to handshake with, use it to apply the remote allow list. - if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, addr.Addr()) { - f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") + if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) { + f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake") return false } } @@ -536,7 +531,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha ci := hostinfo.ConnectionState msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:]) if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). Error("Failed to call noise.ReadMessage") @@ -545,7 +540,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha // near future return false } else if dKey == nil || eKey == nil { - f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). Error("Noise did not arrive at a key") @@ -557,7 +552,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha hs := &NebulaHandshake{} err = hs.Unmarshal(msg) if err != nil || hs.Details == nil { - f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). 
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") // The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again @@ -569,18 +564,18 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha hostinfo.multiportRx = hs.Details.ResponderMultiPort.TxSupported && f.multiPort.Rx } - if hs.Details.ResponderMultiPort != nil && hs.Details.ResponderMultiPort.BasePort != uint32(addr.Port()) { + if hs.Details.ResponderMultiPort != nil && hs.Details.ResponderMultiPort.BasePort != uint32(via.UdpAddr.Port()) { // The other side sent us a handshake from a different port, make sure // we send responses back to the BasePort - addr = netip.AddrPortFrom( - addr.Addr(), + via.UdpAddr = netip.AddrPortFrom( + via.UdpAddr.Addr(), uint16(hs.Details.ResponderMultiPort.BasePort), ) } rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve()) if err != nil { - f.l.WithError(err).WithField("udpAddr", addr). + f.l.WithError(err).WithField("from", via). WithField("vpnAddrs", hostinfo.vpnAddrs). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). Info("Handshake did not contain a certificate") @@ -594,7 +589,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha fp = "" } - e := f.l.WithError(err).WithField("udpAddr", addr). + e := f.l.WithError(err).WithField("from", via). WithField("vpnAddrs", hostinfo.vpnAddrs). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("certFingerprint", fp). @@ -609,7 +604,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha } if len(remoteCert.Certificate.Networks()) == 0 { - f.l.WithError(err).WithField("udpAddr", addr). + f.l.WithError(err).WithField("from", via). WithField("vpnAddrs", hostinfo.vpnAddrs). WithField("cert", remoteCert). 
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). @@ -632,39 +627,30 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha ci.eKey = NewNebulaCipherState(eKey) // Make sure the current udpAddr being used is set for responding - if addr.IsValid() { - hostinfo.SetRemote(addr) + if !via.IsRelayed { + hostinfo.SetRemote(via.UdpAddr) } else { hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) } - var vpnAddrs []netip.Addr - var filteredNetworks []netip.Prefix - for _, network := range vpnNetworks { - // vpnAddrs outside our vpn networks are of no use to us, filter them out - vpnAddr := network.Addr() - if !f.myVpnNetworksTable.Contains(vpnAddr) { - continue + correctHostResponded := false + anyVpnAddrsInCommon := false + vpnAddrs := make([]netip.Addr, len(vpnNetworks)) + for i, network := range vpnNetworks { + vpnAddrs[i] = network.Addr() + if f.myVpnNetworksTable.Contains(network.Addr()) { + anyVpnAddrsInCommon = true + } + if hostinfo.vpnAddrs[0] == network.Addr() { + // todo is it more correct to see if any of hostinfo.vpnAddrs are in the cert? it should have len==1, but one day it might not? + correctHostResponded = true } - - filteredNetworks = append(filteredNetworks, network) - vpnAddrs = append(vpnAddrs, vpnAddr) - } - - if len(vpnAddrs) == 0 { - f.l.WithError(err).WithField("udpAddr", addr). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake") - return true } // Ensure the right host responded - if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) { + if !correctHostResponded { f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks). - WithField("udpAddr", addr). + WithField("from", via). WithField("certName", certName). WithField("certVersion", certVersion). 
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). @@ -674,10 +660,11 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha f.handshakeManager.DeleteHostInfo(hostinfo) // Create a new hostinfo/handshake for the intended vpn ip + //TODO is hostinfo.vpnAddrs[0] always the address to use? f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) { // Block the current used address newHH.hostinfo.remotes = hostinfo.remotes - newHH.hostinfo.remotes.BlockRemote(addr) + newHH.hostinfo.remotes.BlockRemote(via) f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()). WithField("vpnNetworks", vpnNetworks). @@ -700,7 +687,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha ci.window.Update(f.l, 2) duration := time.Since(hh.startTime).Nanoseconds() - f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). + msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). WithField("certName", certName). WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). @@ -709,12 +696,17 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("durationNs", duration). WithField("sentCachedPackets", len(hh.packetStore)). - WithField("multiportTx", hostinfo.multiportTx).WithField("multiportRx", hostinfo.multiportRx). - Info("Handshake message received") + WithField("multiportTx", hostinfo.multiportTx).WithField("multiportRx", hostinfo.multiportRx) + if anyVpnAddrsInCommon { + msgRxL.Info("Handshake message received") + } else { + //todo warn if not lighthouse or relay? 
+ msgRxL.Info("Handshake message received, but no vpnNetworks in common.") + } // Build up the radix for the firewall if we have subnets in the cert hostinfo.vpnAddrs = vpnAddrs - hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks()) + hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate) // Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here f.handshakeManager.Complete(hostinfo, f) @@ -733,7 +725,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore))) } - hostinfo.remotes.ResetBlockedRemotes() + hostinfo.remotes.RefreshFromHandshake(vpnAddrs) f.metricHandshakes.Update(duration) return false diff --git a/handshake_manager.go b/handshake_manager.go index 19c9d593..1e9a0956 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -71,11 +71,12 @@ type HandshakeManager struct { type HandshakeHostInfo struct { sync.Mutex - startTime time.Time // Time that we first started trying with this handshake - ready bool // Is the handshake ready - counter int64 // How many attempts have we made so far - lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt - packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes + startTime time.Time // Time that we first started trying with this handshake + ready bool // Is the handshake ready + initiatingVersionOverride cert.Version // Should we use a non-default cert version for this handshake? 
+ counter int64 // How many attempts have we made so far + lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt + packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes hostinfo *HostInfo } @@ -138,11 +139,11 @@ func (hm *HandshakeManager) Run(ctx context.Context) { } } -func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) { +func (hm *HandshakeManager) HandleIncoming(via ViaSender, packet []byte, h *header.H) { // First remote allow list check before we know the vpnIp - if addr.IsValid() { - if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(addr.Addr()) { - hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") + if !via.IsRelayed { + if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(via.UdpAddr.Addr()) { + hm.l.WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake") return } } @@ -151,11 +152,11 @@ func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, case header.HandshakeIXPSK0: switch h.MessageCounter { case 1: - ixHandshakeStage1(hm.f, addr, via, packet, h) + ixHandshakeStage1(hm.f, via, packet, h) case 2: newHostinfo := hm.queryIndex(h.RemoteIndex) - tearDown := ixHandshakeStage2(hm.f, addr, via, newHostinfo, packet, h) + tearDown := ixHandshakeStage2(hm.f, via, newHostinfo, packet, h) if tearDown && newHostinfo != nil { hm.DeleteHostInfo(newHostinfo.hostinfo) } @@ -294,12 +295,12 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts") // Send a RelayRequest to all known Relay IP's for _, relay := range hostinfo.remotes.relays { - // Don't relay to myself + // Don't relay through the host I'm trying to connect to if relay == vpnIp { continue } - // Don't relay through the host I'm trying to 
connect to + // Don't relay to myself if hm.f.myVpnAddrsTable.Contains(relay) { continue } diff --git a/hostmap.go b/hostmap.go index 4827d574..de863d87 100644 --- a/hostmap.go +++ b/hostmap.go @@ -1,7 +1,9 @@ package nebula import ( + "encoding/json" "errors" + "fmt" "net" "net/netip" "slices" @@ -17,12 +19,10 @@ import ( "github.com/slackhq/nebula/header" ) -// const ProbeLen = 100 const defaultPromoteEvery = 1000 // Count of packets sent before we try moving a tunnel to a preferred underlay ip address const defaultReQueryEvery = 5000 // Count of packets sent before re-querying a hostinfo to the lighthouse const defaultReQueryWait = time.Minute // Minimum amount of seconds to wait before re-querying a hostinfo the lighthouse. Evaluated every ReQueryEvery const MaxRemotes = 10 -const maxRecvError = 4 // MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip // 5 allows for an initial handshake and each host pair re-handshaking twice @@ -214,6 +214,18 @@ func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) { rs.relayForByIdx[idx] = r } +type NetworkType uint8 + +const ( + NetworkTypeUnknown NetworkType = iota + // NetworkTypeVPN is a network that overlaps one or more of the vpnNetworks in our certificate + NetworkTypeVPN + // NetworkTypeVPNPeer is a network that does not overlap one of our networks + NetworkTypeVPNPeer + // NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks() + NetworkTypeUnsafe +) + type HostInfo struct { remote netip.AddrPort remotes *RemoteList @@ -225,11 +237,10 @@ type HostInfo struct { // vpnAddrs is a list of vpn addresses assigned to this host that are within our own vpn networks // The host may have other vpn addresses that are outside our // vpn networks but were removed because they are not usable - vpnAddrs []netip.Addr - recvError atomic.Uint32 + vpnAddrs []netip.Addr - // networks are both all vpn and unsafe networks assigned to this host - networks *bart.Lite + // networks 
is a combination of specific vpn addresses (not prefixes!) and full unsafe networks assigned to this host. + networks *bart.Table[NetworkType] relayState RelayState // If true, we should send to this remote using multiport @@ -273,9 +284,25 @@ type HostInfo struct { } type ViaSender struct { + UdpAddr netip.AddrPort relayHI *HostInfo // relayHI is the host info object of the relay remoteIdx uint32 // remoteIdx is the index included in the header of the received packet relay *Relay // relay contains the rest of the relay information, including the PeerIP of the host trying to communicate with us. + IsRelayed bool // IsRelayed is true if the packet was sent through a relay +} + +func (v ViaSender) String() string { + if v.IsRelayed { + return fmt.Sprintf("%s (relayed)", v.UdpAddr) + } + return v.UdpAddr.String() +} + +func (v ViaSender) MarshalJSON() ([]byte, error) { + if v.IsRelayed { + return json.Marshal(m{"relay": v.UdpAddr}) + } + return json.Marshal(m{"direct": v.UdpAddr}) } type cachedPacket struct { @@ -691,6 +718,7 @@ func (i *HostInfo) GetCert() *cert.CachedCertificate { return nil } +// TODO: Maybe use ViaSender here? func (i *HostInfo) SetRemote(remote netip.AddrPort) { // We copy here because we likely got this remote from a source that reuses the object if i.remote != remote { @@ -701,14 +729,14 @@ func (i *HostInfo) SetRemote(remote netip.AddrPort) { // SetRemoteIfPreferred returns true if the remote was changed. The lastRoam // time on the HostInfo will also be updated. 
-func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) bool { - if !newRemote.IsValid() { - // relays have nil udp Addrs +func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, via ViaSender) bool { + if via.IsRelayed { return false } + currentRemote := i.remote if !currentRemote.IsValid() { - i.SetRemote(newRemote) + i.SetRemote(via.UdpAddr) return true } @@ -721,7 +749,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b return false } - if l.Contains(newRemote.Addr()) { + if l.Contains(via.UdpAddr.Addr()) { newIsPreferred = true } } @@ -731,7 +759,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b i.lastRoam = time.Now() i.lastRoamRemote = currentRemote - i.SetRemote(newRemote) + i.SetRemote(via.UdpAddr) return true } @@ -739,26 +767,26 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b return false } -func (i *HostInfo) RecvErrorExceeded() bool { - if i.recvError.Add(1) >= maxRecvError { - return true - } - return true -} - -func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) { - if len(networks) == 1 && len(unsafeNetworks) == 0 { - // Simple case, no CIDRTree needed - return +// buildNetworks fills in the networks field of HostInfo. It accepts a cert.Certificate so you never ever mix the network types up. 
+func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, c cert.Certificate) { + if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 { + if myVpnNetworksTable.Contains(c.Networks()[0].Addr()) { + return // Simple case, no BART needed + } } - i.networks = new(bart.Lite) - for _, network := range networks { - i.networks.Insert(network) + i.networks = new(bart.Table[NetworkType]) + for _, network := range c.Networks() { + nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen()) + if myVpnNetworksTable.Contains(network.Addr()) { + i.networks.Insert(nprefix, NetworkTypeVPN) + } else { + i.networks.Insert(nprefix, NetworkTypeVPNPeer) + } } - for _, network := range unsafeNetworks { - i.networks.Insert(network) + for _, network := range c.UnsafeNetworks() { + i.networks.Insert(network, NetworkTypeUnsafe) } } diff --git a/inside.go b/inside.go index f1fbe27b..feab01c3 100644 --- a/inside.go +++ b/inside.go @@ -121,9 +121,10 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo * f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q, nil) } -// Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established +// Handshake will attempt to initiate a tunnel with the provided vpn address. This is a no-op if the tunnel is already established or being established +// it does not check if it is within our vpn networks! func (f *Interface) Handshake(vpnAddr netip.Addr) { - f.getOrHandshakeNoRouting(vpnAddr, nil) + f.handshakeManager.GetOrHandshake(vpnAddr, nil) } // getOrHandshakeNoRouting returns nil if the vpnAddr is not routable. @@ -139,7 +140,6 @@ func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback fu // getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary. 
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel. func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { - destinationAddr := fwPacket.RemoteAddr hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback) @@ -232,9 +232,10 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0, nil) } -// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr +// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr. +// This function ignores myVpnNetworksTable, and will always attempt to treat the address as a vpnAddr func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) { - hostInfo, ready := f.getOrHandshakeNoRouting(vpnAddr, func(hh *HandshakeHostInfo) { + hostInfo, ready := f.handshakeManager.GetOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) { hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) }) diff --git a/interface.go b/interface.go index 6d21391c..395f725b 100644 --- a/interface.go +++ b/interface.go @@ -234,6 +234,13 @@ func (f *Interface) activate() { WithField("boringcrypto", boringEnabled()). 
Info("Nebula interface is active") + if f.routines > 1 { + if !f.inside.SupportsMultiqueue() || !f.outside.SupportsMultipleReaders() { + f.routines = 1 + f.l.Warn("routines is not supported on this platform, falling back to a single routine") + } + } + metrics.GetOrRegisterGauge("routines", nil).Update(int64(f.routines)) metrics.GetOrRegisterGauge("multiport.tx_ports", nil).Update(int64(f.multiPort.TxPorts)) @@ -286,7 +293,7 @@ func (f *Interface) listenOut(i int) { nb := make([]byte, 12, 12) li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { - f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) + f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) }) } diff --git a/lighthouse.go b/lighthouse.go index 7a679c76..1510b942 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -24,6 +24,7 @@ import ( ) var ErrHostNotKnown = errors.New("host not known") +var ErrBadDetailsVpnAddr = errors.New("invalid packet, malformed detailsVpnAddr") type LightHouse struct { //TODO: We need a timer wheel to kick out vpnAddrs that haven't reported in a long time @@ -56,7 +57,7 @@ type LightHouse struct { // staticList exists to avoid having a bool in each addrMap entry // since static should be rare staticList atomic.Pointer[map[netip.Addr]struct{}] - lighthouses atomic.Pointer[map[netip.Addr]struct{}] + lighthouses atomic.Pointer[[]netip.Addr] interval atomic.Int64 updateCancel context.CancelFunc @@ -107,7 +108,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)), l: l, } - lighthouses := make(map[netip.Addr]struct{}) + lighthouses := make([]netip.Addr, 0) h.lighthouses.Store(&lighthouses) staticList := make(map[netip.Addr]struct{}) h.staticList.Store(&staticList) @@ -143,7 +144,7 @@ func (lh *LightHouse) GetStaticHostList() map[netip.Addr]struct{} 
{ return *lh.staticList.Load() } -func (lh *LightHouse) GetLighthouses() map[netip.Addr]struct{} { +func (lh *LightHouse) GetLighthouses() []netip.Addr { return *lh.lighthouses.Load() } @@ -306,13 +307,12 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { } if initial || c.HasChanged("lighthouse.hosts") { - lhMap := make(map[netip.Addr]struct{}) - err := lh.parseLighthouses(c, lhMap) + lhList, err := lh.parseLighthouses(c) if err != nil { return err } - lh.lighthouses.Store(&lhMap) + lh.lighthouses.Store(&lhList) if !initial { //NOTE: we are not tearing down existing lighthouse connections because they might be used for non lighthouse traffic lh.l.Info("lighthouse.hosts has changed") @@ -346,36 +346,38 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { return nil } -func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{}) error { +func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) { lhs := c.GetStringSlice("lighthouse.hosts", []string{}) if lh.amLighthouse && len(lhs) != 0 { lh.l.Warn("lighthouse.am_lighthouse enabled on node but upstream lighthouses exist in config") } + out := make([]netip.Addr, len(lhs)) for i, host := range lhs { addr, err := netip.ParseAddr(host) if err != nil { - return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err) + return nil, util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err) } if !lh.myVpnNetworksTable.Contains(addr) { - return util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil) + lh.l.WithFields(m{"vpnAddr": addr, "networks": lh.myVpnNetworks}). 
+ Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not") } - lhMap[addr] = struct{}{} + out[i] = addr } - if !lh.amLighthouse && len(lhMap) == 0 { + if !lh.amLighthouse && len(out) == 0 { lh.l.Warn("No lighthouse.hosts configured, this host will only be able to initiate tunnels with static_host_map entries") } staticList := lh.GetStaticHostList() - for lhAddr, _ := range lhMap { - if _, ok := staticList[lhAddr]; !ok { - return fmt.Errorf("lighthouse %s does not have a static_host_map entry", lhAddr) + for i := range out { + if _, ok := staticList[out[i]]; !ok { + return nil, fmt.Errorf("lighthouse %s does not have a static_host_map entry", out[i]) } } - return nil + return out, nil } func getStaticMapCadence(c *config.C) (time.Duration, error) { @@ -430,7 +432,8 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc } if !lh.myVpnNetworksTable.Contains(vpnAddr) { - return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil) + lh.l.WithFields(m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}). 
+ Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work") } vals, ok := v.([]any) @@ -486,7 +489,7 @@ func (lh *LightHouse) QueryCache(vpnAddrs []netip.Addr) *RemoteList { lh.Lock() defer lh.Unlock() // Add an entry if we don't already have one - return lh.unlockedGetRemoteList(vpnAddrs) + return lh.unlockedGetRemoteList(vpnAddrs) //todo CERT-V2 this contains addrmap lookups we could potentially skip } // queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing @@ -519,11 +522,15 @@ func (lh *LightHouse) queryAndPrepMessage(vpnAddr netip.Addr, f func(*cache) (in } func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) { - // First we check the static mapping - // and do nothing if it is there - if _, ok := lh.GetStaticHostList()[allVpnAddrs[0]]; ok { - return + // First we check the static host map. If any of the VpnAddrs to be deleted are present, do nothing. + staticList := lh.GetStaticHostList() + for _, addr := range allVpnAddrs { + if _, ok := staticList[addr]; ok { + return + } } + + // None of the VpnAddrs were present. Now we can do the deletes. 
lh.Lock() rm, ok := lh.addrMap[allVpnAddrs[0]] if ok { @@ -565,7 +572,7 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t am.unlockedSetHostnamesResults(hr) for _, addrPort := range hr.GetAddrs() { - if !lh.shouldAdd(vpnAddr, addrPort.Addr()) { + if !lh.shouldAdd([]netip.Addr{vpnAddr}, addrPort.Addr()) { continue } switch { @@ -627,23 +634,30 @@ func (lh *LightHouse) addCalculatedRemotes(vpnAddr netip.Addr) bool { return len(calculatedV4) > 0 || len(calculatedV6) > 0 } -// unlockedGetRemoteList -// assumes you have the lh lock +// unlockedGetRemoteList assumes you have the lh lock func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList { - am, ok := lh.addrMap[allAddrs[0]] - if !ok { - am = NewRemoteList(allAddrs, func(a netip.Addr) bool { return lh.shouldAdd(allAddrs[0], a) }) - for _, addr := range allAddrs { - lh.addrMap[addr] = am + // before we go and make a new remotelist, we need to make sure we don't have one for any of this set of vpnaddrs yet + for i, addr := range allAddrs { + am, ok := lh.addrMap[addr] + if ok { + if i != 0 { + lh.addrMap[allAddrs[0]] = am + } + return am } } + + am := NewRemoteList(allAddrs, lh.shouldAdd) + for _, addr := range allAddrs { + lh.addrMap[addr] = am + } return am } -func (lh *LightHouse) shouldAdd(vpnAddr netip.Addr, to netip.Addr) bool { - allow := lh.GetRemoteAllowList().Allow(vpnAddr, to) +func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool { + allow := lh.GetRemoteAllowList().AllowAll(vpnAddrs, to) if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", to).WithField("allow", allow). + lh.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", to).WithField("allow", allow). 
Trace("remoteAllowList.Allow") } if !allow { @@ -698,19 +712,22 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bo } func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool { - if _, ok := lh.GetLighthouses()[vpnAddr]; ok { - return true + l := lh.GetLighthouses() + for i := range l { + if l[i] == vpnAddr { + return true + } } return false } -// TODO: CERT-V2 IsLighthouseAddr should be sufficient, we just need to update the vpnAddrs for lighthouses after a handshake -// so that we know all the lighthouse vpnAddrs, not just the ones we were configured to talk to initially -func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddr []netip.Addr) bool { +func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool { l := lh.GetLighthouses() - for _, a := range vpnAddr { - if _, ok := l[a]; ok { - return true + for i := range vpnAddrs { + for j := range l { + if l[j] == vpnAddrs[i] { + return true + } } } return false @@ -752,7 +769,7 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) { queried := 0 lighthouses := lh.GetLighthouses() - for lhVpnAddr := range lighthouses { + for _, lhVpnAddr := range lighthouses { hi := lh.ifce.GetHostInfo(lhVpnAddr) if hi != nil { v = hi.ConnectionState.myCert.Version() @@ -870,7 +887,7 @@ func (lh *LightHouse) SendUpdate() { updated := 0 lighthouses := lh.GetLighthouses() - for lhVpnAddr := range lighthouses { + for _, lhVpnAddr := range lighthouses { var v cert.Version hi := lh.ifce.GetHostInfo(lhVpnAddr) if hi != nil { @@ -928,7 +945,6 @@ func (lh *LightHouse) SendUpdate() { V4AddrPorts: v4, V6AddrPorts: v6, RelayVpnAddrs: relays, - VpnAddr: netAddrToProtoAddr(lh.myVpnNetworks[0].Addr()), }, } @@ -1048,19 +1064,19 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti return } - useVersion := cert.Version1 - var queryVpnAddr netip.Addr - if n.Details.OldVpnAddr != 0 { - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) 
- queryVpnAddr = netip.AddrFrom4(b) - useVersion = 1 - } else if n.Details.VpnAddr != nil { - queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr) - useVersion = 2 - } else { + queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion() + if err != nil { if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).Debugln("Dropping malformed HostQuery") + lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details). + Debugln("Dropping malformed HostQuery") + } + return + } + if useVersion == cert.Version1 && queryVpnAddr.Is6() { + // this case really shouldn't be possible to represent, but reject it anyway. + if lhh.l.Level >= logrus.DebugLevel { + lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr). + Debugln("invalid vpn addr for v1 handleHostQuery") } return } @@ -1069,9 +1085,6 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti n = lhh.resetMeta() n.Type = NebulaMeta_HostQueryReply if useVersion == cert.Version1 { - if !queryVpnAddr.Is4() { - return 0, fmt.Errorf("invalid vpn addr for v1 handleHostQuery") - } b := queryVpnAddr.As4() n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:]) } else { @@ -1116,8 +1129,9 @@ func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAd if ok { whereToPunch = newDest } else { - //TODO: CERT-V2 this means the destination will have no addresses in common with the punch-ee - //choosing to do nothing for now, but maybe we return an error? 
+ if lhh.l.Level >= logrus.DebugLevel { + lhh.l.WithField("to", crt.Networks()).Debugln("unable to punch to host, no addresses in common") + } } } @@ -1176,19 +1190,17 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul if !r.Is4() { continue } - b = r.As4() n.Details.OldRelayVpnAddrs = append(n.Details.OldRelayVpnAddrs, binary.BigEndian.Uint32(b[:])) } - } else if v == cert.Version2 { for _, r := range c.relay.relay { n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r)) } - } else { - //TODO: CERT-V2 don't panic - panic("unsupported version") + if lhh.l.Level >= logrus.DebugLevel { + lhh.l.WithField("version", v).Debug("unsupported protocol version") + } } } } @@ -1198,18 +1210,16 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs [ return } - lhh.lh.Lock() - - var certVpnAddr netip.Addr - if n.Details.OldVpnAddr != 0 { - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) - certVpnAddr = netip.AddrFrom4(b) - } else if n.Details.VpnAddr != nil { - certVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr) + certVpnAddr, _, err := n.Details.GetVpnAddrAndVersion() + if err != nil { + if lhh.l.Level >= logrus.DebugLevel { + lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("dropping malformed HostQueryReply") + } + return } relays := n.Details.GetRelays() + lhh.lh.Lock() am := lhh.lh.unlockedGetRemoteList([]netip.Addr{certVpnAddr}) am.Lock() lhh.lh.Unlock() @@ -1234,27 +1244,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp return } + // not using GetVpnAddrAndVersion because we don't want to error on a blank detailsVpnAddr var detailsVpnAddr netip.Addr - useVersion := cert.Version1 - if n.Details.OldVpnAddr != 0 { + var useVersion cert.Version + if n.Details.OldVpnAddr != 0 { //v1 always sets this field b := [4]byte{} binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) detailsVpnAddr = netip.AddrFrom4(b) 
useVersion = cert.Version1 - } else if n.Details.VpnAddr != nil { + } else if n.Details.VpnAddr != nil { //this field is "optional" in v2, but if it's set, we should enforce it detailsVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr) useVersion = cert.Version2 } else { - if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("details", n.Details).Debugf("dropping invalid HostUpdateNotification") - } - return + detailsVpnAddr = netip.Addr{} + useVersion = cert.Version2 } - //TODO: CERT-V2 hosts with only v2 certs cannot provide their ipv6 addr when contacting the lighthouse via v4? - //TODO: CERT-V2 why do we care about the vpnAddr in the packet? We know where it came from, right? - //Simple check that the host sent this not someone else - if !slices.Contains(fromVpnAddrs, detailsVpnAddr) { + //Simple check that the host sent this not someone else, if detailsVpnAddr is filled + if detailsVpnAddr.IsValid() && !slices.Contains(fromVpnAddrs, detailsVpnAddr) { if lhh.l.Level >= logrus.DebugLevel { lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update") } @@ -1268,24 +1275,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp am.Lock() lhh.lh.Unlock() - am.unlockedSetV4(fromVpnAddrs[0], detailsVpnAddr, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4) - am.unlockedSetV6(fromVpnAddrs[0], detailsVpnAddr, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6) + am.unlockedSetV4(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4) + am.unlockedSetV6(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6) am.unlockedSetRelay(fromVpnAddrs[0], relays) am.Unlock() n = lhh.resetMeta() n.Type = NebulaMeta_HostUpdateNotificationAck - - if useVersion == cert.Version1 { + switch useVersion { + case cert.Version1: if !fromVpnAddrs[0].Is4() { lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 
vpn ip in a v1 message") return } vpnAddrB := fromVpnAddrs[0].As4() n.Details.OldVpnAddr = binary.BigEndian.Uint32(vpnAddrB[:]) - } else if useVersion == cert.Version2 { - n.Details.VpnAddr = netAddrToProtoAddr(fromVpnAddrs[0]) - } else { + case cert.Version2: + // do nothing, we want to send a blank message + default: lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version") return } @@ -1303,13 +1310,20 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) { //It's possible the lighthouse is communicating with us using a non primary vpn addr, //which means we need to compare all fromVpnAddrs against all configured lighthouse vpn addrs. - //maybe one day we'll have a better idea, if it matters. if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) { return } + detailsVpnAddr, _, err := n.Details.GetVpnAddrAndVersion() + if err != nil { + if lhh.l.Level >= logrus.DebugLevel { + lhh.l.WithField("details", n.Details).WithError(err).Debugln("dropping invalid HostPunchNotification") + } + return + } + empty := []byte{0} - punch := func(vpnPeer netip.AddrPort) { + punch := func(vpnPeer netip.AddrPort, logVpnAddr netip.Addr) { if !vpnPeer.IsValid() { return } @@ -1321,48 +1335,38 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn }() if lhh.l.Level >= logrus.DebugLevel { - var logVpnAddr netip.Addr - if n.Details.OldVpnAddr != 0 { - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) - logVpnAddr = netip.AddrFrom4(b) - } else if n.Details.VpnAddr != nil { - logVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr) - } lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr) } } + remoteAllowList := lhh.lh.GetRemoteAllowList() for _, a := range n.Details.V4AddrPorts { - punch(protoV4AddrPortToNetAddrPort(a)) + b := protoV4AddrPortToNetAddrPort(a) + if 
remoteAllowList.Allow(detailsVpnAddr, b.Addr()) { + punch(b, detailsVpnAddr) + } } for _, a := range n.Details.V6AddrPorts { - punch(protoV6AddrPortToNetAddrPort(a)) + b := protoV6AddrPortToNetAddrPort(a) + if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) { + punch(b, detailsVpnAddr) + } } // This sends a nebula test packet to the host trying to contact us. In the case // of a double nat or other difficult scenario, this may help establish // a tunnel. if lhh.lh.punchy.GetRespond() { - var queryVpnAddr netip.Addr - if n.Details.OldVpnAddr != 0 { - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) - queryVpnAddr = netip.AddrFrom4(b) - } else if n.Details.VpnAddr != nil { - queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr) - } - go func() { time.Sleep(lhh.lh.punchy.GetRespondDelay()) if lhh.l.Level >= logrus.DebugLevel { - lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", queryVpnAddr) + lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", detailsVpnAddr) } //NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine // for each punchBack packet. We should move this into a timerwheel or a single goroutine // managed by a channel. 
- w.SendMessageToVpnAddr(header.Test, header.TestRequest, queryVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + w.SendMessageToVpnAddr(header.Test, header.TestRequest, detailsVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) }() } } @@ -1441,3 +1445,17 @@ func findNetworkUnion(prefixes []netip.Prefix, addrs []netip.Addr) (netip.Addr, } return netip.Addr{}, false } + +func (d *NebulaMetaDetails) GetVpnAddrAndVersion() (netip.Addr, cert.Version, error) { + if d.OldVpnAddr != 0 { + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], d.OldVpnAddr) + detailsVpnAddr := netip.AddrFrom4(b) + return detailsVpnAddr, cert.Version1, nil + } else if d.VpnAddr != nil { + detailsVpnAddr := protoAddrToNetAddr(d.VpnAddr) + return detailsVpnAddr, cert.Version2, nil + } else { + return netip.Addr{}, cert.Version1, ErrBadDetailsVpnAddr + } +} diff --git a/lighthouse_test.go b/lighthouse_test.go index eb2d26e4..fea1d1ed 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -14,7 +14,7 @@ import ( "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/yaml.v3" + "go.yaml.in/yaml/v3" ) func TestOldIPv4Only(t *testing.T) { @@ -493,3 +493,123 @@ func Test_findNetworkUnion(t *testing.T) { out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81}) assert.False(t, ok) } + +func TestLighthouse_Dont_Delete_Static_Hosts(t *testing.T) { + l := test.NewLogger() + + myUdpAddr2 := netip.MustParseAddrPort("1.2.3.4:4242") + + testSameHostNotStatic := netip.MustParseAddr("10.128.0.41") + testStaticHost := netip.MustParseAddr("10.128.0.42") + //myVpnIp := netip.MustParseAddr("10.128.0.2") + + c := config.NewC(l) + lh1 := "10.128.0.2" + c.Settings["lighthouse"] = map[string]any{ + "hosts": []any{lh1}, + "interval": "1s", + } + + c.Settings["listen"] = map[string]any{"port": 4242} + c.Settings["static_host_map"] = map[string]any{ + lh1: []any{"1.1.1.1:4242"}, + "10.128.0.42": 
[]any{"1.2.3.4:4242"}, + } + + myVpnNet := netip.MustParsePrefix("10.128.0.1/24") + nt := new(bart.Lite) + nt.Insert(myVpnNet) + cs := &CertState{ + myVpnNetworks: []netip.Prefix{myVpnNet}, + myVpnNetworksTable: nt, + } + lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) + require.NoError(t, err) + lh.ifce = &mockEncWriter{} + + //test that we actually have the static entry: + out := lh.Query(testStaticHost) + assert.NotNil(t, out) + assert.Equal(t, out.vpnAddrs[0], testStaticHost) + out.Rebuild([]netip.Prefix{}) //why tho + assert.Equal(t, out.addrs[0], myUdpAddr2) + + //bolt on a lower numbered primary IP + am := lh.unlockedGetRemoteList([]netip.Addr{testStaticHost}) + am.vpnAddrs = []netip.Addr{testSameHostNotStatic, testStaticHost} + lh.addrMap[testSameHostNotStatic] = am + out.Rebuild([]netip.Prefix{}) //??? + + //test that we actually have the static entry: + out = lh.Query(testStaticHost) + assert.NotNil(t, out) + assert.Equal(t, out.vpnAddrs[0], testSameHostNotStatic) + assert.Equal(t, out.vpnAddrs[1], testStaticHost) + assert.Equal(t, out.addrs[0], myUdpAddr2) + + //test that we actually have the static entry for BOTH: + out2 := lh.Query(testSameHostNotStatic) + assert.Same(t, out2, out) + + //now do the delete + lh.DeleteVpnAddrs([]netip.Addr{testSameHostNotStatic, testStaticHost}) + //verify + out = lh.Query(testSameHostNotStatic) + assert.NotNil(t, out) + if out == nil { + t.Fatal("expected non-nil query for the static host") + } + assert.Equal(t, out.vpnAddrs[0], testSameHostNotStatic) + assert.Equal(t, out.vpnAddrs[1], testStaticHost) + assert.Equal(t, out.addrs[0], myUdpAddr2) +} + +func TestLighthouse_DeletesWork(t *testing.T) { + l := test.NewLogger() + + myUdpAddr2 := netip.MustParseAddrPort("1.2.3.4:4242") + testHost := netip.MustParseAddr("10.128.0.42") + + c := config.NewC(l) + lh1 := "10.128.0.2" + c.Settings["lighthouse"] = map[string]any{ + "hosts": []any{lh1}, + "interval": "1s", + } + + c.Settings["listen"] = 
map[string]any{"port": 4242} + c.Settings["static_host_map"] = map[string]any{ + lh1: []any{"1.1.1.1:4242"}, + } + + myVpnNet := netip.MustParsePrefix("10.128.0.1/24") + nt := new(bart.Lite) + nt.Insert(myVpnNet) + cs := &CertState{ + myVpnNetworks: []netip.Prefix{myVpnNet}, + myVpnNetworksTable: nt, + } + lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) + require.NoError(t, err) + lh.ifce = &mockEncWriter{} + + //insert the host + am := lh.unlockedGetRemoteList([]netip.Addr{testHost}) + am.vpnAddrs = []netip.Addr{testHost} + am.addrs = []netip.AddrPort{myUdpAddr2} + lh.addrMap[testHost] = am + am.Rebuild([]netip.Prefix{}) //??? + + //test that we actually have the entry: + out := lh.Query(testHost) + assert.NotNil(t, out) + assert.Equal(t, out.vpnAddrs[0], testHost) + out.Rebuild([]netip.Prefix{}) //why tho + assert.Equal(t, out.addrs[0], myUdpAddr2) + + //now do the delete + lh.DeleteVpnAddrs([]netip.Addr{testHost}) + //verify + out = lh.Query(testHost) + assert.Nil(t, out) +} diff --git a/main.go b/main.go index 14071ee0..2c2d9b82 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,8 @@ import ( "fmt" "net" "net/netip" + "runtime/debug" + "strings" "time" "github.com/sirupsen/logrus" @@ -13,7 +15,7 @@ import ( "github.com/slackhq/nebula/sshd" "github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/util" - "gopkg.in/yaml.v3" + "go.yaml.in/yaml/v3" ) type m = map[string]any @@ -27,6 +29,10 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } }() + if buildVersion == "" { + buildVersion = moduleVersion() + } + l := logger l.Formatter = &logrus.TextFormatter{ FullTimestamp: true, @@ -75,7 +81,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg if c.GetBool("sshd.enabled", false) { sshStart, err = configSSH(l, ssh, c) if err != nil { - return nil, util.ContextualizeIfNeeded("Error while configuring the sshd", err) + l.WithError(err).Warn("Failed to configure sshd, ssh 
debugging will not be available") + sshStart = nil } } @@ -328,3 +335,18 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg connManager.Start, }, nil } + +func moduleVersion() string { + info, ok := debug.ReadBuildInfo() + if !ok { + return "" + } + + for _, dep := range info.Deps { + if dep.Path == "github.com/slackhq/nebula" { + return strings.TrimPrefix(dep.Version, "v") + } + } + + return "" +} diff --git a/outside.go b/outside.go index 573b20f7..32ce1af2 100644 --- a/outside.go +++ b/outside.go @@ -19,21 +19,21 @@ const ( minFwPacketLen = 4 ) -func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { +func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { err := h.Parse(packet) if err != nil { // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors if len(packet) > 1 { - f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip, err) + f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", via, err) } return } //l.Error("in packet ", header, packet[HeaderLen:]) - if ip.IsValid() { - if f.myVpnNetworksTable.Contains(ip.Addr()) { + if !via.IsRelayed { + if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) { if f.l.Level >= logrus.DebugLevel { - f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet") + f.l.WithField("from", via).Debug("Refusing to process double encrypted packet") } return } @@ -54,8 +54,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] switch h.Type { case header.Message: - // TODO handleEncrypted sends directly to addr on error. 
Handle this in the tunneling case. - if !f.handleEncrypted(ci, ip, h) { + if !f.handleEncrypted(ci, via, h) { return } @@ -79,7 +78,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] // Successfully validated the thing. Get rid of the Relay header. signedPayload = signedPayload[header.Len:] // Pull the Roaming parts up here, and return in all call paths. - f.handleHostRoaming(hostinfo, ip) + f.handleHostRoaming(hostinfo, via) // Track usage of both the HostInfo and the Relay for the received & authenticated packet f.connectionManager.In(hostinfo) f.connectionManager.RelayUsed(h.RemoteIndex) @@ -96,7 +95,14 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] case TerminalType: // If I am the target of this relay, process the unwrapped packet // From this recursive point, all these variables are 'burned'. We shouldn't rely on them again. - f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) + via = ViaSender{ + UdpAddr: via.UdpAddr, + relayHI: hostinfo, + remoteIdx: relay.RemoteIndex, + relay: relay, + IsRelayed: true, + } + f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) return case ForwardingType: // Find the target HostInfo relay object @@ -126,31 +132,32 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] case header.LightHouse: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, ip, h) { + if !f.handleEncrypted(ci, via, h) { return } d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). + hostinfo.logger(f.l).WithError(err).WithField("from", via). WithField("packet", packet). 
Error("Failed to decrypt lighthouse packet") return } - lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f) + //TODO: assert via is not relayed + lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, d, f) // Fallthrough to the bottom to record incoming traffic case header.Test: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, ip, h) { + if !f.handleEncrypted(ci, via, h) { return } d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). + hostinfo.logger(f.l).WithError(err).WithField("from", via). WithField("packet", packet). Error("Failed to decrypt test packet") return @@ -159,7 +166,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] if h.Subtype == header.TestRequest { // This testRequest might be from TryPromoteBest, so we should roam // to the new IP address before responding - f.handleHostRoaming(hostinfo, ip) + f.handleHostRoaming(hostinfo, via) f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out) } @@ -170,34 +177,34 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] case header.Handshake: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - f.handshakeManager.HandleIncoming(ip, via, packet, h) + f.handshakeManager.HandleIncoming(via, packet, h) return case header.RecvError: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - f.handleRecvError(ip, h) + f.handleRecvError(via.UdpAddr, h) return case header.CloseTunnel: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, ip, h) { + if !f.handleEncrypted(ci, via, h) { return } - hostinfo.logger(f.l).WithField("udpAddr", ip). + hostinfo.logger(f.l).WithField("from", via). 
Info("Close tunnel received, tearing down.") f.closeTunnel(hostinfo) return case header.Control: - if !f.handleEncrypted(ci, ip, h) { + if !f.handleEncrypted(ci, via, h) { return } d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). + hostinfo.logger(f.l).WithError(err).WithField("from", via). WithField("packet", packet). Error("Failed to decrypt Control packet") return @@ -207,11 +214,11 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] default: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip) + hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", via) return } - f.handleHostRoaming(hostinfo, ip) + f.handleHostRoaming(hostinfo, via) f.connectionManager.In(hostinfo) } @@ -230,50 +237,51 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) { f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) } -func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort) { - if udpAddr.IsValid() && hostinfo.remote != udpAddr { +func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) { + if !via.IsRelayed && hostinfo.remote != via.UdpAddr { if hostinfo.multiportRx { // If the remote is sending with multiport, we aren't roaming unless // the IP has changed - if hostinfo.remote.Addr().Compare(udpAddr.Addr()) == 0 { + if hostinfo.remote.Addr().Compare(via.UdpAddr.Addr()) == 0 { return } // Keep the port from the original hostinfo, because the remote is transmitting from multiport ports - udpAddr = netip.AddrPortFrom(udpAddr.Addr(), hostinfo.remote.Port()) + via.UdpAddr = netip.AddrPortFrom(via.UdpAddr.Addr(), hostinfo.remote.Port()) } - - if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, udpAddr.Addr()) { - hostinfo.logger(f.l).WithField("newAddr", 
udpAddr).Debug("lighthouse.remote_allow_list denied roaming") + if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) { + hostinfo.logger(f.l).WithField("newAddr", via.UdpAddr).Debug("lighthouse.remote_allow_list denied roaming") return } - if !hostinfo.lastRoam.IsZero() && udpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { + if !hostinfo.lastRoam.IsZero() && via.UdpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { if f.l.Level >= logrus.DebugLevel { - hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr). + hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr). Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds) } return } - hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr). + hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr). Info("Host roamed to new udp ip/port.") hostinfo.lastRoam = time.Now() hostinfo.lastRoamRemote = hostinfo.remote - hostinfo.SetRemote(udpAddr) + hostinfo.SetRemote(via.UdpAddr) } } -func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool { - // If connectionstate exists and the replay protector allows, process packet - // Else, send recv errors for 300 seconds after a restart to allow fast reconnection. 
- if ci == nil || !ci.window.Check(f.l, h.MessageCounter) { - if addr.IsValid() { - f.maybeSendRecvError(addr, h.RemoteIndex) - return false - } else { - return false +// handleEncrypted returns true if a packet should be processed, false otherwise +func (f *Interface) handleEncrypted(ci *ConnectionState, via ViaSender, h *header.H) bool { + // If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect + if ci == nil { + if !via.IsRelayed { + f.maybeSendRecvError(via.UdpAddr, h.RemoteIndex) } + return false + } + // If the window check fails, refuse to process the packet, but don't send a recv error + if !ci.window.Check(f.l, h.MessageCounter) { + return false } return true @@ -547,10 +555,6 @@ func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) { return } - if !hostinfo.RecvErrorExceeded() { - return - } - if hostinfo.remote.IsValid() && hostinfo.remote != addr { f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote) return diff --git a/overlay/device.go b/overlay/device.go index 07146ab3..b6077aba 100644 --- a/overlay/device.go +++ b/overlay/device.go @@ -13,5 +13,6 @@ type Device interface { Networks() []netip.Prefix Name() string RoutesFor(netip.Addr) routing.Gateways + SupportsMultiqueue() bool NewMultiQueueReader() (io.ReadWriteCloser, error) } diff --git a/overlay/tun.go b/overlay/tun.go index 4a6377d2..3a61d186 100644 --- a/overlay/tun.go +++ b/overlay/tun.go @@ -1,6 +1,8 @@ package overlay import ( + "fmt" + "net" "net/netip" "github.com/sirupsen/logrus" @@ -70,3 +72,51 @@ func findRemovedRoutes(newRoutes, oldRoutes []Route) []Route { return removed } + +func prefixToMask(prefix netip.Prefix) netip.Addr { + pLen := 128 + if prefix.Addr().Is4() { + pLen = 32 + } + + addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen)) + return addr +} + +func flipBytes(b []byte) []byte { + for i := 0; i < len(b); i++ { + b[i] ^= 0xFF + } + return b +} +func orBytes(a []byte, b []byte) []byte { 
+ ret := make([]byte, len(a)) + for i := 0; i < len(a); i++ { + ret[i] = a[i] | b[i] + } + return ret +} + +func getBroadcast(cidr netip.Prefix) netip.Addr { + broadcast, _ := netip.AddrFromSlice( + orBytes( + cidr.Addr().AsSlice(), + flipBytes(prefixToMask(cidr).AsSlice()), + ), + ) + return broadcast +} + +func selectGateway(dest netip.Prefix, gateways []netip.Prefix) (netip.Prefix, error) { + for _, gateway := range gateways { + if dest.Addr().Is4() && gateway.Addr().Is4() { + return gateway, nil + } + + if dest.Addr().Is6() && gateway.Addr().Is6() { + return gateway, nil + } + } + + return netip.Prefix{}, fmt.Errorf("no gateway found for %v in the list of vpn networks", dest) +} diff --git a/overlay/tun_android.go b/overlay/tun_android.go index df1ed8d8..eddef882 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -95,6 +95,10 @@ func (t *tun) Name() string { return "android" } +func (t *tun) SupportsMultiqueue() bool { + return false +} + func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for android") } diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 7f6ba4f0..128c2001 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -7,7 +7,6 @@ import ( "errors" "fmt" "io" - "net" "net/netip" "os" "sync/atomic" @@ -295,7 +294,6 @@ func (t *tun) activate6(network netip.Prefix) error { Vltime: 0xffffffff, Pltime: 0xffffffff, }, - //TODO: CERT-V2 should we disable DAD (duplicate address detection) and mark this as a secured address? 
Flags: _IN6_IFF_NODAD, } @@ -551,16 +549,10 @@ func (t *tun) Name() string { return t.Device } +func (t *tun) SupportsMultiqueue() bool { + return false +} + func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin") } - -func prefixToMask(prefix netip.Prefix) netip.Addr { - pLen := 128 - if prefix.Addr().Is4() { - pLen = 32 - } - - addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen)) - return addr -} diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go index 131879d2..aa3dddaf 100644 --- a/overlay/tun_disabled.go +++ b/overlay/tun_disabled.go @@ -105,6 +105,10 @@ func (t *disabledTun) Write(b []byte) (int, error) { return len(b), nil } +func (t *disabledTun) SupportsMultiqueue() bool { + return true +} + func (t *disabledTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return t, nil } diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 2a89cbca..8d292263 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -10,11 +10,9 @@ import ( "io" "io/fs" "net/netip" - "os" - "os/exec" - "strconv" "sync/atomic" "syscall" + "time" "unsafe" "github.com/gaissmai/bart" @@ -22,12 +20,18 @@ import ( "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" + netroute "golang.org/x/net/route" + "golang.org/x/sys/unix" ) const ( // FIODGNAME is defined in sys/sys/filio.h on FreeBSD // For 32-bit systems, use FIODGNAME_32 (not defined in this file: 0x80086678) - FIODGNAME = 0x80106678 + FIODGNAME = 0x80106678 + TUNSIFMODE = 0x8004745e + TUNSIFHEAD = 0x80047460 + OSIOCAIFADDR_IN6 = 0x8088691b + IN6_IFF_NODAD = 0x0020 ) type fiodgnameArg struct { @@ -37,43 +41,159 @@ type fiodgnameArg struct { } type ifreqRename struct { - Name [16]byte + Name [unix.IFNAMSIZ]byte Data uintptr } type ifreqDestroy struct { - Name [16]byte + Name [unix.IFNAMSIZ]byte pad [16]byte } +type ifReq struct { + Name 
[unix.IFNAMSIZ]byte + Flags uint16 +} + +type ifreqMTU struct { + Name [unix.IFNAMSIZ]byte + MTU int32 +} + +type addrLifetime struct { + Expire uint64 + Preferred uint64 + Vltime uint32 + Pltime uint32 +} + +type ifreqAlias4 struct { + Name [unix.IFNAMSIZ]byte + Addr unix.RawSockaddrInet4 + DstAddr unix.RawSockaddrInet4 + MaskAddr unix.RawSockaddrInet4 + VHid uint32 +} + +type ifreqAlias6 struct { + Name [unix.IFNAMSIZ]byte + Addr unix.RawSockaddrInet6 + DstAddr unix.RawSockaddrInet6 + PrefixMask unix.RawSockaddrInet6 + Flags uint32 + Lifetime addrLifetime + VHid uint32 +} + type tun struct { Device string vpnNetworks []netip.Prefix MTU int Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] + linkAddr *netroute.LinkAddr l *logrus.Logger + devFd int +} - io.ReadWriteCloser +func (t *tun) Read(to []byte) (int, error) { + // use readv() to read from the tunnel device, to eliminate the need for copying the buffer + if t.devFd < 0 { + return -1, syscall.EINVAL + } + + // first 4 bytes is protocol family, in network byte order + head := make([]byte, 4) + + iovecs := []syscall.Iovec{ + {&head[0], 4}, + {&to[0], uint64(len(to))}, + } + + n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2)) + + var err error + if errno != 0 { + err = syscall.Errno(errno) + } else { + err = nil + } + // fix bytes read number to exclude header + bytesRead := int(n) + if bytesRead < 0 { + return bytesRead, err + } else if bytesRead < 4 { + return 0, err + } else { + return bytesRead - 4, err + } +} + +// Write is only valid for single threaded use +func (t *tun) Write(from []byte) (int, error) { + // use writev() to write to the tunnel device, to eliminate the need for copying the buffer + if t.devFd < 0 { + return -1, syscall.EINVAL + } + + if len(from) <= 1 { + return 0, syscall.EIO + } + ipVer := from[0] >> 4 + var head []byte + // first 4 bytes is protocol family, in network byte order + if 
ipVer == 4 { + head = []byte{0, 0, 0, syscall.AF_INET} + } else if ipVer == 6 { + head = []byte{0, 0, 0, syscall.AF_INET6} + } else { + return 0, fmt.Errorf("unable to determine IP version from packet") + } + iovecs := []syscall.Iovec{ + {&head[0], 4}, + {&from[0], uint64(len(from))}, + } + + n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.devFd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2)) + + var err error + if errno != 0 { + err = syscall.Errno(errno) + } else { + err = nil + } + + return int(n) - 4, err } func (t *tun) Close() error { - if t.ReadWriteCloser != nil { - if err := t.ReadWriteCloser.Close(); err != nil { - return err - } - - s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP) + if t.devFd >= 0 { + err := syscall.Close(t.devFd) if err != nil { - return err + t.l.WithError(err).Error("Error closing device") } - defer syscall.Close(s) + t.devFd = -1 - ifreq := ifreqDestroy{Name: t.deviceBytes()} + c := make(chan struct{}) + go func() { + // destroying the interface can block if a read() is still pending. Do this asynchronously. 
+ defer close(c) + s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP) + if err == nil { + defer syscall.Close(s) + ifreq := ifreqDestroy{Name: t.deviceBytes()} + err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq))) + } + if err != nil { + t.l.WithError(err).Error("Error destroying tunnel") + } + }() - // Destroy the interface - err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq))) - return err + // wait up to 1 second so we start blocking at the ioctl + select { + case <-c: + case <-time.After(1 * time.Second): + } } return nil @@ -85,32 +205,37 @@ func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open existing tun device - var file *os.File + var fd int var err error deviceName := c.GetString("tun.dev", "") if deviceName != "" { - file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0) + fd, err = syscall.Open("/dev/"+deviceName, syscall.O_RDWR, 0) } if errors.Is(err, fs.ErrNotExist) || deviceName == "" { // If the device doesn't already exist, request a new one and rename it - file, err = os.OpenFile("/dev/tun", os.O_RDWR, 0) + fd, err = syscall.Open("/dev/tun", syscall.O_RDWR, 0) } if err != nil { return nil, err } - rawConn, err := file.SyscallConn() - if err != nil { - return nil, fmt.Errorf("SyscallConn: %v", err) + // Read the name of the interface + var name [16]byte + arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)} + ctrlErr := ioctl(uintptr(fd), FIODGNAME, uintptr(unsafe.Pointer(&arg))) + + if ctrlErr == nil { + // set broadcast mode and multicast + ifmode := uint32(unix.IFF_BROADCAST | unix.IFF_MULTICAST) + ctrlErr = ioctl(uintptr(fd), TUNSIFMODE, uintptr(unsafe.Pointer(&ifmode))) + } + + if ctrlErr == nil { + // turn on link-layer mode, to support ipv6 + ifhead := uint32(1) + ctrlErr = ioctl(uintptr(fd), TUNSIFHEAD, 
uintptr(unsafe.Pointer(&ifhead))) } - var name [16]byte - var ctrlErr error - rawConn.Control(func(fd uintptr) { - // Read the name of the interface - arg := fiodgnameArg{length: 16, buf: unsafe.Pointer(&name)} - ctrlErr = ioctl(fd, FIODGNAME, uintptr(unsafe.Pointer(&arg))) - }) if ctrlErr != nil { return nil, err } @@ -122,11 +247,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( // If the name doesn't match the desired interface name, rename it now if ifName != deviceName { - s, err := syscall.Socket( - syscall.AF_INET, - syscall.SOCK_DGRAM, - syscall.IPPROTO_IP, - ) + s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP) if err != nil { return nil, err } @@ -149,11 +270,11 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( } t := &tun{ - ReadWriteCloser: file, - Device: deviceName, - vpnNetworks: vpnNetworks, - MTU: c.GetInt("tun.mtu", DefaultMTU), - l: l, + Device: deviceName, + vpnNetworks: vpnNetworks, + MTU: c.GetInt("tun.mtu", DefaultMTU), + l: l, + devFd: fd, } err = t.reload(c, true) @@ -172,38 +293,111 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( } func (t *tun) addIp(cidr netip.Prefix) error { - var err error - // TODO use syscalls instead of exec.Command - cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String()) - t.l.Debug("command: ", cmd.String()) - if err = cmd.Run(); err != nil { - return fmt.Errorf("failed to run 'ifconfig': %s", err) + if cidr.Addr().Is4() { + ifr := ifreqAlias4{ + Name: t.deviceBytes(), + Addr: unix.RawSockaddrInet4{ + Len: unix.SizeofSockaddrInet4, + Family: unix.AF_INET, + Addr: cidr.Addr().As4(), + }, + DstAddr: unix.RawSockaddrInet4{ + Len: unix.SizeofSockaddrInet4, + Family: unix.AF_INET, + Addr: getBroadcast(cidr).As4(), + }, + MaskAddr: unix.RawSockaddrInet4{ + Len: unix.SizeofSockaddrInet4, + Family: unix.AF_INET, + Addr: prefixToMask(cidr).As4(), + }, + VHid: 0, + } + 
s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP) + if err != nil { + return err + } + defer syscall.Close(s) + // Note: unix.SIOCAIFADDR corresponds to FreeBSD's OSIOCAIFADDR + if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil { + return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err) + } + return nil } - cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), "-interface", t.Device) - t.l.Debug("command: ", cmd.String()) - if err = cmd.Run(); err != nil { - return fmt.Errorf("failed to run 'route add': %s", err) + if cidr.Addr().Is6() { + ifr := ifreqAlias6{ + Name: t.deviceBytes(), + Addr: unix.RawSockaddrInet6{ + Len: unix.SizeofSockaddrInet6, + Family: unix.AF_INET6, + Addr: cidr.Addr().As16(), + }, + PrefixMask: unix.RawSockaddrInet6{ + Len: unix.SizeofSockaddrInet6, + Family: unix.AF_INET6, + Addr: prefixToMask(cidr).As16(), + }, + Lifetime: addrLifetime{ + Expire: 0, + Preferred: 0, + Vltime: 0xffffffff, + Pltime: 0xffffffff, + }, + Flags: IN6_IFF_NODAD, + } + s, err := syscall.Socket(syscall.AF_INET6, syscall.SOCK_DGRAM, syscall.IPPROTO_IP) + if err != nil { + return err + } + defer syscall.Close(s) + + if err := ioctl(uintptr(s), OSIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil { + return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err) + } + return nil } - cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU)) - t.l.Debug("command: ", cmd.String()) - if err = cmd.Run(); err != nil { - return fmt.Errorf("failed to run 'ifconfig': %s", err) - } - - // Unsafe path routes - return t.addRoutes(false) + return fmt.Errorf("unknown address type %v", cidr) } func (t *tun) Activate() error { + // Setup our default MTU + err := t.setMTU() + if err != nil { + return err + } + + linkAddr, err := getLinkAddr(t.Device) + if err != nil { + return err + } + if linkAddr == nil { + return fmt.Errorf("unable to discover 
link_addr for tun interface") + } + t.linkAddr = linkAddr + for i := range t.vpnNetworks { err := t.addIp(t.vpnNetworks[i]) if err != nil { return err } } - return nil + + return t.addRoutes(false) +} + +func (t *tun) setMTU() error { + // Set the MTU on the device + s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP) + if err != nil { + return err + } + defer syscall.Close(s) + + ifm := ifreqMTU{Name: t.deviceBytes(), MTU: int32(t.MTU)} + err = ioctl(uintptr(s), unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))) + return err } func (t *tun) reload(c *config.C, initial bool) error { @@ -256,6 +450,10 @@ func (t *tun) Name() string { return t.Device } +func (t *tun) SupportsMultiqueue() bool { + return false +} + func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd") } @@ -268,15 +466,16 @@ func (t *tun) addRoutes(logErrors bool) error { continue } - cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), "-interface", t.Device) - t.l.Debug("command: ", cmd.String()) - if err := cmd.Run(); err != nil { - retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err) + err := addRoute(r.Cidr, t.linkAddr) + if err != nil { + retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err) if logErrors { retErr.Log(t.l) } else { return retErr } + } else { + t.l.WithField("route", r).Info("Added route") } } @@ -289,9 +488,8 @@ func (t *tun) removeRoutes(routes []Route) error { continue } - cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), "-interface", t.Device) - t.l.Debug("command: ", cmd.String()) - if err := cmd.Run(); err != nil { + err := delRoute(r.Cidr, t.linkAddr) + if err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") } else { t.l.WithField("route", r).Info("Removed route") @@ -306,3 +504,120 @@ func (t *tun) 
deviceBytes() (o [16]byte) { } return } + +func addRoute(prefix netip.Prefix, gateway netroute.Addr) error { + sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) + if err != nil { + return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) + } + defer unix.Close(sock) + + route := &netroute.RouteMessage{ + Version: unix.RTM_VERSION, + Type: unix.RTM_ADD, + Flags: unix.RTF_UP, + Seq: 1, + } + + if prefix.Addr().Is4() { + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()}, + unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()}, + unix.RTAX_GATEWAY: gateway, + } + } else { + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()}, + unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()}, + unix.RTAX_GATEWAY: gateway, + } + } + + data, err := route.Marshal() + if err != nil { + return fmt.Errorf("failed to create route.RouteMessage: %w", err) + } + + _, err = unix.Write(sock, data[:]) + if err != nil { + if errors.Is(err, unix.EEXIST) { + // Try to do a change + route.Type = unix.RTM_CHANGE + data, err = route.Marshal() + if err != nil { + return fmt.Errorf("failed to create route.RouteMessage for change: %w", err) + } + _, err = unix.Write(sock, data[:]) + fmt.Println("DOING CHANGE") + return err + } + return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err) + } + + return nil +} + +func delRoute(prefix netip.Prefix, gateway netroute.Addr) error { + sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) + if err != nil { + return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) + } + defer unix.Close(sock) + + route := netroute.RouteMessage{ + Version: unix.RTM_VERSION, + Type: unix.RTM_DELETE, + Seq: 1, + } + + if prefix.Addr().Is4() { + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()}, + unix.RTAX_NETMASK: 
&netroute.Inet4Addr{IP: prefixToMask(prefix).As4()}, + unix.RTAX_GATEWAY: gateway, + } + } else { + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()}, + unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()}, + unix.RTAX_GATEWAY: gateway, + } + } + + data, err := route.Marshal() + if err != nil { + return fmt.Errorf("failed to create route.RouteMessage: %w", err) + } + _, err = unix.Write(sock, data[:]) + if err != nil { + return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err) + } + + return nil +} + +// getLinkAddr Gets the link address for the interface of the given name +func getLinkAddr(name string) (*netroute.LinkAddr, error) { + rib, err := netroute.FetchRIB(unix.AF_UNSPEC, unix.NET_RT_IFLIST, 0) + if err != nil { + return nil, err + } + msgs, err := netroute.ParseRIB(unix.NET_RT_IFLIST, rib) + if err != nil { + return nil, err + } + + for _, m := range msgs { + switch m := m.(type) { + case *netroute.InterfaceMessage: + if m.Name == name { + sa, ok := m.Addrs[unix.RTAX_IFP].(*netroute.LinkAddr) + if ok { + return sa, nil + } + } + } + } + + return nil, nil +} diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index e51e1120..0ce01df8 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -151,6 +151,10 @@ func (t *tun) Name() string { return "iOS" } +func (t *tun) SupportsMultiqueue() bool { + return false +} + func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for ios") } diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 4c509ba8..32bf51f5 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -216,6 +216,10 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } +func (t *tun) SupportsMultiqueue() bool { + return true +} + func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { 
@@ -293,7 +297,6 @@ func (t *tun) addIPs(link netlink.Link) error { //add all new addresses for i := range newAddrs { - //TODO: CERT-V2 do we want to stack errors and try as many ops as possible? //AddrReplace still adds new IPs, but if their properties change it will change them as well if err := netlink.AddrReplace(link, newAddrs[i]); err != nil { return err @@ -361,6 +364,11 @@ func (t *tun) Activate() error { t.l.WithError(err).Error("Failed to set tun tx queue length") } + const modeNone = 1 + if err = netlink.LinkSetIP6AddrGenMode(link, modeNone); err != nil { + t.l.WithError(err).Warn("Failed to disable link local address generation") + } + if err = t.addIPs(link); err != nil { return err } @@ -578,48 +586,42 @@ func (t *tun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool { } func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways { - var gateways routing.Gateways link, err := netlink.LinkByName(t.Device) if err != nil { - t.l.WithField("Devicename", t.Device).Error("Ignoring route update: failed to get link by name") + t.l.WithField("deviceName", t.Device).Error("Ignoring route update: failed to get link by name") return gateways } // If this route is relevant to our interface and there is a gateway then add it - if r.LinkIndex == link.Attrs().Index && len(r.Gw) > 0 { - gwAddr, ok := netip.AddrFromSlice(r.Gw) - if !ok { - t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address") - } else { - gwAddr = gwAddr.Unmap() - - if !t.isGatewayInVpnNetworks(gwAddr) { - // Gateway isn't in our overlay network, ignore - t.l.WithField("route", r).Debug("Ignoring route update, not in our network") - } else { + if r.LinkIndex == link.Attrs().Index { + gwAddr, ok := getGatewayAddr(r.Gw, r.Via) + if ok { + if t.isGatewayInVpnNetworks(gwAddr) { gateways = append(gateways, routing.NewGateway(gwAddr, 1)) + } else { + // Gateway isn't in our overlay network, ignore + t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in 
our network") } + } else { + t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address") } } for _, p := range r.MultiPath { // If this route is relevant to our interface and there is a gateway then add it - if p.LinkIndex == link.Attrs().Index && len(p.Gw) > 0 { - gwAddr, ok := netip.AddrFromSlice(p.Gw) - if !ok { - t.l.WithField("route", r).Debug("Ignoring multipath route update, invalid gateway address") - } else { - gwAddr = gwAddr.Unmap() - - if !t.isGatewayInVpnNetworks(gwAddr) { - // Gateway isn't in our overlay network, ignore - t.l.WithField("route", r).Debug("Ignoring route update, not in our network") - } else { - // p.Hops+1 = weight of the route + if p.LinkIndex == link.Attrs().Index { + gwAddr, ok := getGatewayAddr(p.Gw, p.Via) + if ok { + if t.isGatewayInVpnNetworks(gwAddr) { gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1)) + } else { + // Gateway isn't in our overlay network, ignore + t.l.WithField("route", r).Debug("Ignoring route update, gateway is not in our network") } + } else { + t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway or via address") } } } @@ -628,16 +630,38 @@ func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways { return gateways } +func getGatewayAddr(gw net.IP, via netlink.Destination) (netip.Addr, bool) { + // Try to use the old RTA_GATEWAY first + gwAddr, ok := netip.AddrFromSlice(gw) + if !ok { + // Fallback to the new RTA_VIA + rVia, ok := via.(*netlink.Via) + if ok { + gwAddr, ok = netip.AddrFromSlice(rVia.Addr) + } + } + + if gwAddr.IsValid() { + gwAddr = gwAddr.Unmap() + return gwAddr, true + } + + return netip.Addr{}, false +} + func (t *tun) updateRoutes(r netlink.RouteUpdate) { - gateways := t.getGatewaysFromRoute(&r.Route) - if len(gateways) == 0 { // No gateways relevant to our network, no routing changes required. 
t.l.WithField("route", r).Debug("Ignoring route update, no gateways") return } + if r.Dst == nil { + t.l.WithField("route", r).Debug("Ignoring route update, no destination address") + return + } + dstAddr, ok := netip.AddrFromSlice(r.Dst.IP) if !ok { t.l.WithField("route", r).Debug("Ignoring route update, invalid destination address") diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index 5ff9b0fe..2986c895 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -4,13 +4,12 @@ package overlay import ( + "errors" "fmt" "io" "net/netip" "os" - "os/exec" "regexp" - "strconv" "sync/atomic" "syscall" "unsafe" @@ -20,11 +19,42 @@ import ( "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" + netroute "golang.org/x/net/route" + "golang.org/x/sys/unix" ) -type ifreqDestroy struct { - Name [16]byte - pad [16]byte +const ( + SIOCAIFADDR_IN6 = 0x8080696b + TUNSIFHEAD = 0x80047442 + TUNSIFMODE = 0x80047458 +) + +type ifreqAlias4 struct { + Name [unix.IFNAMSIZ]byte + Addr unix.RawSockaddrInet4 + DstAddr unix.RawSockaddrInet4 + MaskAddr unix.RawSockaddrInet4 +} + +type ifreqAlias6 struct { + Name [unix.IFNAMSIZ]byte + Addr unix.RawSockaddrInet6 + DstAddr unix.RawSockaddrInet6 + PrefixMask unix.RawSockaddrInet6 + Flags uint32 + Lifetime addrLifetime +} + +type ifreq struct { + Name [unix.IFNAMSIZ]byte + data int +} + +type addrLifetime struct { + Expire uint64 + Preferred uint64 + Vltime uint32 + Pltime uint32 } type tun struct { @@ -34,40 +64,18 @@ type tun struct { Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] l *logrus.Logger - - io.ReadWriteCloser + f *os.File + fd int } -func (t *tun) Close() error { - if t.ReadWriteCloser != nil { - if err := t.ReadWriteCloser.Close(); err != nil { - return err - } - - s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP) - if err != nil { - return err - } - defer syscall.Close(s) - - ifreq := 
ifreqDestroy{Name: t.deviceBytes()} - - err = ioctl(uintptr(s), syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifreq))) - - return err - } - return nil -} +var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in NetBSD") } -var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) - func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open tun device - var file *os.File var err error deviceName := c.GetString("tun.dev", "") if deviceName == "" { @@ -77,17 +85,23 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified") } - file, err = os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0) + fd, err := unix.Open("/dev/"+deviceName, os.O_RDWR, 0) if err != nil { return nil, err } + err = unix.SetNonblock(fd, true) + if err != nil { + l.WithError(err).Warn("Failed to set the tun device as nonblocking") + } + t := &tun{ - ReadWriteCloser: file, - Device: deviceName, - vpnNetworks: vpnNetworks, - MTU: c.GetInt("tun.mtu", DefaultMTU), - l: l, + f: os.NewFile(uintptr(fd), ""), + fd: fd, + Device: deviceName, + vpnNetworks: vpnNetworks, + MTU: c.GetInt("tun.mtu", DefaultMTU), + l: l, } err = t.reload(c, true) @@ -105,40 +119,225 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( return t, nil } +func (t *tun) Close() error { + if t.f != nil { + if err := t.f.Close(); err != nil { + return fmt.Errorf("error closing tun file: %w", err) + } + + // t.f.Close should have handled it for us but let's be extra sure + _ = unix.Close(t.fd) + + s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_IP) + if err != nil { + return err + } + defer syscall.Close(s) + + ifr := ifreq{Name: t.deviceBytes()} + err = ioctl(uintptr(s), 
syscall.SIOCIFDESTROY, uintptr(unsafe.Pointer(&ifr))) + return err + } + return nil +} + +func (t *tun) Read(to []byte) (int, error) { + rc, err := t.f.SyscallConn() + if err != nil { + return 0, fmt.Errorf("failed to get syscall conn for tun: %w", err) + } + + var errno syscall.Errno + var n uintptr + err = rc.Read(func(fd uintptr) bool { + // first 4 bytes is protocol family, in network byte order + head := [4]byte{} + iovecs := []syscall.Iovec{ + {&head[0], 4}, + {&to[0], uint64(len(to))}, + } + + n, _, errno = syscall.Syscall(syscall.SYS_READV, fd, uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2)) + if errno.Temporary() { + // We got an EAGAIN, EINTR, or EWOULDBLOCK, go again + return false + } + return true + }) + if err != nil { + if err == syscall.EBADF || err.Error() == "use of closed file" { + // Go doesn't export poll.ErrFileClosing but happily reports it to us so here we are + // https://github.com/golang/go/blob/master/src/internal/poll/fd_poll_runtime.go#L121 + return 0, os.ErrClosed + } + return 0, fmt.Errorf("failed to make read call for tun: %w", err) + } + + if errno != 0 { + return 0, fmt.Errorf("failed to make inner read call for tun: %w", errno) + } + + // fix bytes read number to exclude header + bytesRead := int(n) + if bytesRead < 0 { + return bytesRead, nil + } else if bytesRead < 4 { + return 0, nil + } else { + return bytesRead - 4, nil + } +} + +// Write is only valid for single threaded use +func (t *tun) Write(from []byte) (int, error) { + if len(from) <= 1 { + return 0, syscall.EIO + } + + ipVer := from[0] >> 4 + var head [4]byte + // first 4 bytes is protocol family, in network byte order + if ipVer == 4 { + head[3] = syscall.AF_INET + } else if ipVer == 6 { + head[3] = syscall.AF_INET6 + } else { + return 0, fmt.Errorf("unable to determine IP version from packet") + } + + rc, err := t.f.SyscallConn() + if err != nil { + return 0, err + } + + var errno syscall.Errno + var n uintptr + err = rc.Write(func(fd uintptr) bool { + iovecs := 
[]syscall.Iovec{ + {&head[0], 4}, + {&from[0], uint64(len(from))}, + } + + n, _, errno = syscall.Syscall(syscall.SYS_WRITEV, fd, uintptr(unsafe.Pointer(&iovecs[0])), uintptr(2)) + // According to NetBSD documentation for TUN, writes will only return errors in which + // this packet will never be delivered so just go on living life. + return true + }) + if err != nil { + return 0, err + } + + if errno != 0 { + return 0, errno + } + + return int(n) - 4, err +} + func (t *tun) addIp(cidr netip.Prefix) error { - var err error + if cidr.Addr().Is4() { + var req ifreqAlias4 + req.Name = t.deviceBytes() + req.Addr = unix.RawSockaddrInet4{ + Len: unix.SizeofSockaddrInet4, + Family: unix.AF_INET, + Addr: cidr.Addr().As4(), + } + req.DstAddr = unix.RawSockaddrInet4{ + Len: unix.SizeofSockaddrInet4, + Family: unix.AF_INET, + Addr: cidr.Addr().As4(), + } + req.MaskAddr = unix.RawSockaddrInet4{ + Len: unix.SizeofSockaddrInet4, + Family: unix.AF_INET, + Addr: prefixToMask(cidr).As4(), + } - // TODO use syscalls instead of exec.Command - cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String()) - t.l.Debug("command: ", cmd.String()) - if err = cmd.Run(); err != nil { - return fmt.Errorf("failed to run 'ifconfig': %s", err) + s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP) + if err != nil { + return err + } + defer syscall.Close(s) + + if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&req))); err != nil { + return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr(), err) + } + + return nil } - cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String()) - t.l.Debug("command: ", cmd.String()) - if err = cmd.Run(); err != nil { - return fmt.Errorf("failed to run 'route add': %s", err) + if cidr.Addr().Is6() { + var req ifreqAlias6 + req.Name = t.deviceBytes() + req.Addr = unix.RawSockaddrInet6{ + Len: unix.SizeofSockaddrInet6, + Family: unix.AF_INET6, + Addr: 
cidr.Addr().As16(), + } + req.PrefixMask = unix.RawSockaddrInet6{ + Len: unix.SizeofSockaddrInet6, + Family: unix.AF_INET6, + Addr: prefixToMask(cidr).As16(), + } + req.Lifetime = addrLifetime{ + Vltime: 0xffffffff, + Pltime: 0xffffffff, + } + + s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP) + if err != nil { + return err + } + defer syscall.Close(s) + + if err := ioctl(uintptr(s), SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&req))); err != nil { + return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err) + } + return nil } - cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU)) - t.l.Debug("command: ", cmd.String()) - if err = cmd.Run(); err != nil { - return fmt.Errorf("failed to run 'ifconfig': %s", err) - } - - // Unsafe path routes - return t.addRoutes(false) + return fmt.Errorf("unknown address type %v", cidr) } func (t *tun) Activate() error { + mode := int32(unix.IFF_BROADCAST) + err := ioctl(uintptr(t.fd), TUNSIFMODE, uintptr(unsafe.Pointer(&mode))) + if err != nil { + return fmt.Errorf("failed to set tun device mode: %w", err) + } + + v := 1 + err = ioctl(uintptr(t.fd), TUNSIFHEAD, uintptr(unsafe.Pointer(&v))) + if err != nil { + return fmt.Errorf("failed to set tun device head: %w", err) + } + + err = t.doIoctlByName(unix.SIOCSIFMTU, uint32(t.MTU)) + if err != nil { + return fmt.Errorf("failed to set tun mtu: %w", err) + } + for i := range t.vpnNetworks { - err := t.addIp(t.vpnNetworks[i]) + err = t.addIp(t.vpnNetworks[i]) if err != nil { return err } } - return nil + + return t.addRoutes(false) +} + +func (t *tun) doIoctlByName(ctl uintptr, value uint32) error { + s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP) + if err != nil { + return err + } + defer syscall.Close(s) + + ir := ifreq{Name: t.deviceBytes(), data: int(value)} + err = ioctl(uintptr(s), ctl, uintptr(unsafe.Pointer(&ir))) + return err } func (t *tun) reload(c *config.C, initial bool) error { @@ -191,27 
+390,33 @@ func (t *tun) Name() string { return t.Device } +func (t *tun) SupportsMultiqueue() bool { + return false +} + func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd") } func (t *tun) addRoutes(logErrors bool) error { routes := *t.Routes.Load() + for _, r := range routes { if len(r.Via) == 0 || !r.Install { // We don't allow route MTUs so only install routes with a via continue } - cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String()) - t.l.Debug("command: ", cmd.String()) - if err := cmd.Run(); err != nil { - retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err) + err := addRoute(r.Cidr, t.vpnNetworks) + if err != nil { + retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err) if logErrors { retErr.Log(t.l) } else { return retErr } + } else { + t.l.WithField("route", r).Info("Added route") } } @@ -224,10 +429,8 @@ func (t *tun) removeRoutes(routes []Route) error { continue } - //TODO: CERT-V2 is this right? 
- cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String()) - t.l.Debug("command: ", cmd.String()) - if err := cmd.Run(); err != nil { + err := delRoute(r.Cidr, t.vpnNetworks) + if err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") } else { t.l.WithField("route", r).Info("Removed route") @@ -242,3 +445,109 @@ func (t *tun) deviceBytes() (o [16]byte) { } return } + +func addRoute(prefix netip.Prefix, gateways []netip.Prefix) error { + sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) + if err != nil { + return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) + } + defer unix.Close(sock) + + route := &netroute.RouteMessage{ + Version: unix.RTM_VERSION, + Type: unix.RTM_ADD, + Flags: unix.RTF_UP | unix.RTF_GATEWAY, + Seq: 1, + } + + if prefix.Addr().Is4() { + gw, err := selectGateway(prefix, gateways) + if err != nil { + return err + } + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()}, + unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()}, + unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()}, + } + } else { + gw, err := selectGateway(prefix, gateways) + if err != nil { + return err + } + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()}, + unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()}, + unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()}, + } + } + + data, err := route.Marshal() + if err != nil { + return fmt.Errorf("failed to create route.RouteMessage: %w", err) + } + + _, err = unix.Write(sock, data[:]) + if err != nil { + if errors.Is(err, unix.EEXIST) { + // Try to do a change + route.Type = unix.RTM_CHANGE + data, err = route.Marshal() + if err != nil { + return fmt.Errorf("failed to create route.RouteMessage for change: %w", err) + } + _, err = unix.Write(sock, data[:]) + return 
err + } + return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err) + } + + return nil +} + +func delRoute(prefix netip.Prefix, gateways []netip.Prefix) error { + sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) + if err != nil { + return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) + } + defer unix.Close(sock) + + route := netroute.RouteMessage{ + Version: unix.RTM_VERSION, + Type: unix.RTM_DELETE, + Seq: 1, + } + + if prefix.Addr().Is4() { + gw, err := selectGateway(prefix, gateways) + if err != nil { + return err + } + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()}, + unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()}, + unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()}, + } + } else { + gw, err := selectGateway(prefix, gateways) + if err != nil { + return err + } + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()}, + unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()}, + unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()}, + } + } + + data, err := route.Marshal() + if err != nil { + return fmt.Errorf("failed to create route.RouteMessage: %w", err) + } + _, err = unix.Write(sock, data[:]) + if err != nil { + return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err) + } + + return nil +} diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index 67a9a5f8..9209b795 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -4,23 +4,50 @@ package overlay import ( + "errors" "fmt" "io" "net/netip" "os" - "os/exec" "regexp" - "strconv" "sync/atomic" "syscall" + "unsafe" "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/util" + netroute "golang.org/x/net/route" + "golang.org/x/sys/unix" ) +const ( + SIOCAIFADDR_IN6 = 
0x8080691a +) + +type ifreqAlias4 struct { + Name [unix.IFNAMSIZ]byte + Addr unix.RawSockaddrInet4 + DstAddr unix.RawSockaddrInet4 + MaskAddr unix.RawSockaddrInet4 +} + +type ifreqAlias6 struct { + Name [unix.IFNAMSIZ]byte + Addr unix.RawSockaddrInet6 + DstAddr unix.RawSockaddrInet6 + PrefixMask unix.RawSockaddrInet6 + Flags uint32 + Lifetime [2]uint32 +} + +type ifreq struct { + Name [unix.IFNAMSIZ]byte + data int +} + type tun struct { Device string vpnNetworks []netip.Prefix @@ -28,48 +55,46 @@ type tun struct { Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] l *logrus.Logger - - io.ReadWriteCloser - + f *os.File + fd int // cache out buffer since we need to prepend 4 bytes for tun metadata out []byte } -func (t *tun) Close() error { - if t.ReadWriteCloser != nil { - return t.ReadWriteCloser.Close() - } - - return nil -} - -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { - return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD") -} - var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { + return nil, fmt.Errorf("newTunFromFd not supported in openbsd") +} + func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { + // Try to open tun device + var err error deviceName := c.GetString("tun.dev", "") if deviceName == "" { - return nil, fmt.Errorf("a device name in the format of tunN must be specified") + return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified") } - if !deviceNameRE.MatchString(deviceName) { - return nil, fmt.Errorf("a device name in the format of tunN must be specified") + return nil, fmt.Errorf("a device name in the format of /dev/tunN must be specified") } - file, err := os.OpenFile("/dev/"+deviceName, os.O_RDWR, 0) + fd, err := unix.Open("/dev/"+deviceName, os.O_RDWR, 0) if err != nil { return nil, err } + err 
= unix.SetNonblock(fd, true) + if err != nil { + l.WithError(err).Warn("Failed to set the tun device as nonblocking") + } + t := &tun{ - ReadWriteCloser: file, - Device: deviceName, - vpnNetworks: vpnNetworks, - MTU: c.GetInt("tun.mtu", DefaultMTU), - l: l, + f: os.NewFile(uintptr(fd), ""), + fd: fd, + Device: deviceName, + vpnNetworks: vpnNetworks, + MTU: c.GetInt("tun.mtu", DefaultMTU), + l: l, } err = t.reload(c, true) @@ -87,6 +112,154 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) ( return t, nil } +func (t *tun) Close() error { + if t.f != nil { + if err := t.f.Close(); err != nil { + return fmt.Errorf("error closing tun file: %w", err) + } + + // t.f.Close should have handled it for us but let's be extra sure + _ = unix.Close(t.fd) + } + return nil +} + +func (t *tun) Read(to []byte) (int, error) { + buf := make([]byte, len(to)+4) + + n, err := t.f.Read(buf) + + copy(to, buf[4:]) + return n - 4, err +} + +// Write is only valid for single threaded use +func (t *tun) Write(from []byte) (int, error) { + buf := t.out + if cap(buf) < len(from)+4 { + buf = make([]byte, len(from)+4) + t.out = buf + } + buf = buf[:len(from)+4] + + if len(from) == 0 { + return 0, syscall.EIO + } + + // Determine the IP Family for the NULL L2 Header + ipVer := from[0] >> 4 + if ipVer == 4 { + buf[3] = syscall.AF_INET + } else if ipVer == 6 { + buf[3] = syscall.AF_INET6 + } else { + return 0, fmt.Errorf("unable to determine IP version from packet") + } + + copy(buf[4:], from) + + n, err := t.f.Write(buf) + return n - 4, err +} + +func (t *tun) addIp(cidr netip.Prefix) error { + if cidr.Addr().Is4() { + var req ifreqAlias4 + req.Name = t.deviceBytes() + req.Addr = unix.RawSockaddrInet4{ + Len: unix.SizeofSockaddrInet4, + Family: unix.AF_INET, + Addr: cidr.Addr().As4(), + } + req.DstAddr = unix.RawSockaddrInet4{ + Len: unix.SizeofSockaddrInet4, + Family: unix.AF_INET, + Addr: cidr.Addr().As4(), + } + req.MaskAddr = unix.RawSockaddrInet4{ + Len: 
unix.SizeofSockaddrInet4, + Family: unix.AF_INET, + Addr: prefixToMask(cidr).As4(), + } + + s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP) + if err != nil { + return err + } + defer syscall.Close(s) + + if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&req))); err != nil { + return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr(), err) + } + + err = addRoute(cidr, t.vpnNetworks) + if err != nil { + return fmt.Errorf("failed to set route for vpn network %v: %w", cidr, err) + } + + return nil + } + + if cidr.Addr().Is6() { + var req ifreqAlias6 + req.Name = t.deviceBytes() + req.Addr = unix.RawSockaddrInet6{ + Len: unix.SizeofSockaddrInet6, + Family: unix.AF_INET6, + Addr: cidr.Addr().As16(), + } + req.PrefixMask = unix.RawSockaddrInet6{ + Len: unix.SizeofSockaddrInet6, + Family: unix.AF_INET6, + Addr: prefixToMask(cidr).As16(), + } + req.Lifetime[0] = 0xffffffff + req.Lifetime[1] = 0xffffffff + + s, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_IP) + if err != nil { + return err + } + defer syscall.Close(s) + + if err := ioctl(uintptr(s), SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&req))); err != nil { + return fmt.Errorf("failed to set tun address %s: %s", cidr.Addr().String(), err) + } + + return nil + } + + return fmt.Errorf("unknown address type %v", cidr) +} + +func (t *tun) Activate() error { + err := t.doIoctlByName(unix.SIOCSIFMTU, uint32(t.MTU)) + if err != nil { + return fmt.Errorf("failed to set tun mtu: %w", err) + } + + for i := range t.vpnNetworks { + err = t.addIp(t.vpnNetworks[i]) + if err != nil { + return err + } + } + + return t.addRoutes(false) +} + +func (t *tun) doIoctlByName(ctl uintptr, value uint32) error { + s, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP) + if err != nil { + return err + } + defer syscall.Close(s) + + ir := ifreq{Name: t.deviceBytes(), data: int(value)} + err = ioctl(uintptr(s), ctl, uintptr(unsafe.Pointer(&ir))) + return err +} + func 
(t *tun) reload(c *config.C, initial bool) error { change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { @@ -124,63 +297,46 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) addIp(cidr netip.Prefix) error { - var err error - // TODO use syscalls instead of exec.Command - cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String()) - t.l.Debug("command: ", cmd.String()) - if err = cmd.Run(); err != nil { - return fmt.Errorf("failed to run 'ifconfig': %s", err) - } - - cmd = exec.Command("/sbin/ifconfig", t.Device, "mtu", strconv.Itoa(t.MTU)) - t.l.Debug("command: ", cmd.String()) - if err = cmd.Run(); err != nil { - return fmt.Errorf("failed to run 'ifconfig': %s", err) - } - - cmd = exec.Command("/sbin/route", "-n", "add", "-inet", cidr.String(), cidr.Addr().String()) - t.l.Debug("command: ", cmd.String()) - if err = cmd.Run(); err != nil { - return fmt.Errorf("failed to run 'route add': %s", err) - } - - // Unsafe path routes - return t.addRoutes(false) -} - -func (t *tun) Activate() error { - for i := range t.vpnNetworks { - err := t.addIp(t.vpnNetworks[i]) - if err != nil { - return err - } - } - return nil -} - func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { r, _ := t.routeTree.Load().Lookup(ip) return r } +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks +} + +func (t *tun) Name() string { + return t.Device +} + +func (t *tun) SupportsMultiqueue() bool { + return false +} + +func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { + return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd") +} + func (t *tun) addRoutes(logErrors bool) error { routes := *t.Routes.Load() + for _, r := range routes { if len(r.Via) == 0 || !r.Install { // We don't allow route MTUs so only install routes with a via continue } - //TODO: CERT-V2 is this right? 
- cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String()) - t.l.Debug("command: ", cmd.String()) - if err := cmd.Run(); err != nil { - retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]any{"route": r}, err) + + err := addRoute(r.Cidr, t.vpnNetworks) + if err != nil { + retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err) if logErrors { retErr.Log(t.l) } else { return retErr } + } else { + t.l.WithField("route", r).Info("Added route") } } @@ -192,10 +348,9 @@ func (t *tun) removeRoutes(routes []Route) error { if !r.Install { continue } - //TODO: CERT-V2 is this right? - cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String()) - t.l.Debug("command: ", cmd.String()) - if err := cmd.Run(); err != nil { + + err := delRoute(r.Cidr, t.vpnNetworks) + if err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") } else { t.l.WithField("route", r).Info("Removed route") @@ -204,52 +359,115 @@ func (t *tun) removeRoutes(routes []Route) error { return nil } -func (t *tun) Networks() []netip.Prefix { - return t.vpnNetworks -} - -func (t *tun) Name() string { - return t.Device -} - -func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { - return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd") -} - -func (t *tun) Read(to []byte) (int, error) { - buf := make([]byte, len(to)+4) - - n, err := t.ReadWriteCloser.Read(buf) - - copy(to, buf[4:]) - return n - 4, err -} - -// Write is only valid for single threaded use -func (t *tun) Write(from []byte) (int, error) { - buf := t.out - if cap(buf) < len(from)+4 { - buf = make([]byte, len(from)+4) - t.out = buf +func (t *tun) deviceBytes() (o [16]byte) { + for i, c := range t.Device { + o[i] = byte(c) } - buf = buf[:len(from)+4] + return +} - if len(from) == 0 { - return 0, syscall.EIO +func addRoute(prefix 
netip.Prefix, gateways []netip.Prefix) error { + sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) + if err != nil { + return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) + } + defer unix.Close(sock) + + route := &netroute.RouteMessage{ + Version: unix.RTM_VERSION, + Type: unix.RTM_ADD, + Flags: unix.RTF_UP | unix.RTF_GATEWAY, + Seq: 1, } - // Determine the IP Family for the NULL L2 Header - ipVer := from[0] >> 4 - if ipVer == 4 { - buf[3] = syscall.AF_INET - } else if ipVer == 6 { - buf[3] = syscall.AF_INET6 + if prefix.Addr().Is4() { + gw, err := selectGateway(prefix, gateways) + if err != nil { + return err + } + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()}, + unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()}, + unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()}, + } } else { - return 0, fmt.Errorf("unable to determine IP version from packet") + gw, err := selectGateway(prefix, gateways) + if err != nil { + return err + } + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()}, + unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()}, + unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()}, + } } - copy(buf[4:], from) + data, err := route.Marshal() + if err != nil { + return fmt.Errorf("failed to create route.RouteMessage: %w", err) + } - n, err := t.ReadWriteCloser.Write(buf) - return n - 4, err + _, err = unix.Write(sock, data[:]) + if err != nil { + if errors.Is(err, unix.EEXIST) { + // Try to do a change + route.Type = unix.RTM_CHANGE + data, err = route.Marshal() + if err != nil { + return fmt.Errorf("failed to create route.RouteMessage for change: %w", err) + } + _, err = unix.Write(sock, data[:]) + return err + } + return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err) + } + + return nil +} + +func delRoute(prefix netip.Prefix, gateways 
[]netip.Prefix) error { + sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) + if err != nil { + return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) + } + defer unix.Close(sock) + + route := netroute.RouteMessage{ + Version: unix.RTM_VERSION, + Type: unix.RTM_DELETE, + Seq: 1, + } + + if prefix.Addr().Is4() { + gw, err := selectGateway(prefix, gateways) + if err != nil { + return err + } + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()}, + unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()}, + unix.RTAX_GATEWAY: &netroute.Inet4Addr{IP: gw.Addr().As4()}, + } + } else { + gw, err := selectGateway(prefix, gateways) + if err != nil { + return err + } + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()}, + unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()}, + unix.RTAX_GATEWAY: &netroute.Inet6Addr{IP: gw.Addr().As16()}, + } + } + + data, err := route.Marshal() + if err != nil { + return fmt.Errorf("failed to create route.RouteMessage: %w", err) + } + _, err = unix.Write(sock, data[:]) + if err != nil { + return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err) + } + + return nil } diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index b6712fbb..3477de3d 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -132,6 +132,10 @@ func (t *TestTun) Read(b []byte) (int, error) { return len(p), nil } +func (t *TestTun) SupportsMultiqueue() bool { + return false +} + func (t *TestTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented") } diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index 7aac1289..b4d78b66 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -234,6 +234,10 @@ func (t *winTun) Write(b []byte) (int, error) { return t.tun.Write(b, 0) } +func (t *winTun) 
SupportsMultiqueue() bool { + return false +} + func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for windows") } diff --git a/overlay/user.go b/overlay/user.go index 8a56d667..1f92d4e9 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -46,6 +46,10 @@ func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways { return routing.Gateways{routing.NewGateway(ip, 1)} } +func (d *UserDevice) SupportsMultiqueue() bool { + return true +} + func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { return d, nil } diff --git a/pkclient/pkclient_cgo.go b/pkclient/pkclient_cgo.go index a2ead551..e6286093 100644 --- a/pkclient/pkclient_cgo.go +++ b/pkclient/pkclient_cgo.go @@ -180,6 +180,7 @@ func (c *PKClient) DeriveNoise(peerPubKey []byte) ([]byte, error) { pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, true), pkcs11.NewAttribute(pkcs11.CKA_WRAP, true), pkcs11.NewAttribute(pkcs11.CKA_UNWRAP, true), + pkcs11.NewAttribute(pkcs11.CKA_VALUE_LEN, NoiseKeySize), } // Set up the parameters which include the peer's public key diff --git a/pki.go b/pki.go index 9cab4918..19869d58 100644 --- a/pki.go +++ b/pki.go @@ -100,55 +100,62 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError { currentState := p.cs.Load() if newState.v1Cert != nil { if currentState.v1Cert == nil { - return util.NewContextualError("v1 certificate was added, restart required", nil, err) - } + //adding certs is fine, actually. Networks-in-common confirmed in newCertState(). + } else { + // did IP in cert change? if so, don't set + if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) { + return util.NewContextualError( + "Networks in new cert was different from old", + m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks(), "cert_version": cert.Version1}, + nil, + ) + } - // did IP in cert change? 
if so, don't set - if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) { - return util.NewContextualError( - "Networks in new cert was different from old", - m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()}, - nil, - ) + if currentState.v1Cert.Curve() != newState.v1Cert.Curve() { + return util.NewContextualError( + "Curve in new v1 cert was different from old", + m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve(), "cert_version": cert.Version1}, + nil, + ) + } } - - if currentState.v1Cert.Curve() != newState.v1Cert.Curve() { - return util.NewContextualError( - "Curve in new cert was different from old", - m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()}, - nil, - ) - } - - } else if currentState.v1Cert != nil { - //TODO: CERT-V2 we should be able to tear this down - return util.NewContextualError("v1 certificate was removed, restart required", nil, err) } if newState.v2Cert != nil { if currentState.v2Cert == nil { - return util.NewContextualError("v2 certificate was added, restart required", nil, err) - } + //adding certs is fine, actually + } else { + // did IP in cert change? if so, don't set + if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) { + return util.NewContextualError( + "Networks in new cert was different from old", + m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks(), "cert_version": cert.Version2}, + nil, + ) + } - // did IP in cert change? 
if so, don't set - if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) { - return util.NewContextualError( - "Networks in new cert was different from old", - m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()}, - nil, - ) - } - - if currentState.v2Cert.Curve() != newState.v2Cert.Curve() { - return util.NewContextualError( - "Curve in new cert was different from old", - m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()}, - nil, - ) + if currentState.v2Cert.Curve() != newState.v2Cert.Curve() { + return util.NewContextualError( + "Curve in new cert was different from old", + m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve(), "cert_version": cert.Version2}, + nil, + ) + } } } else if currentState.v2Cert != nil { - return util.NewContextualError("v2 certificate was removed, restart required", nil, err) + //newState.v1Cert is non-nil bc empty certstates aren't permitted + if newState.v1Cert == nil { + return util.NewContextualError("v1 and v2 certs are nil, this should be impossible", nil, err) + } + //if we're going to v1-only, we need to make sure we didn't orphan any v2-cert vpnaddrs + if !slices.Equal(currentState.v2Cert.Networks(), newState.v1Cert.Networks()) { + return util.NewContextualError( + "Removing a V2 cert is not permitted unless it has identical networks to the new V1 cert", + m{"new_v1_networks": newState.v1Cert.Networks(), "old_v2_networks": currentState.v2Cert.Networks()}, + nil, + ) + } } // Cipher cant be hot swapped so just leave it at what it was before @@ -173,7 +180,6 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError { p.cs.Store(newState) - //TODO: CERT-V2 newState needs a stringer that does json if initial { p.l.WithField("cert", newState).Debug("Client nebula certificate(s)") } else { @@ -359,7 +365,9 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p 
return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil) } - //TODO: CERT-V2 make sure v2 has v1s address + if v1.Networks()[0] != v2.Networks()[0] { + return nil, util.NewContextualError("v1 and v2 networks are not the same", nil, nil) + } cs.initiatingVersion = dv } @@ -515,9 +523,13 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) { return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err) } - for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) { - l.WithField("fingerprint", fp).Info("Blocklisting cert") - caPool.BlocklistFingerprint(fp) + bl := c.GetStringSlice("pki.blocklist", []string{}) + if len(bl) > 0 { + for _, fp := range bl { + caPool.BlocklistFingerprint(fp) + } + + l.WithField("fingerprintCount", len(bl)).Info("Blocklisted certificates") } return caPool, nil diff --git a/remote_list.go b/remote_list.go index 6baed29b..1304fd51 100644 --- a/remote_list.go +++ b/remote_list.go @@ -190,7 +190,7 @@ type RemoteList struct { // The full list of vpn addresses assigned to this host vpnAddrs []netip.Addr - // A deduplicated set of addresses. Any accessor should lock beforehand. + // A deduplicated set of underlay addresses. Any accessor should lock beforehand. addrs []netip.AddrPort // A set of relay addresses. VpnIp addresses that the remote identified as relays. @@ -201,8 +201,10 @@ type RemoteList struct { // For learned addresses, this is the vpnIp that sent the packet cache map[netip.Addr]*cache - hr *hostnamesResults - shouldAdd func(netip.Addr) bool + hr *hostnamesResults + + // shouldAdd is a nillable function that decides if x should be added to addrs. + shouldAdd func(vpnAddrs []netip.Addr, x netip.Addr) bool // This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip. 
// They should not be tried again during a handshake @@ -213,7 +215,7 @@ type RemoteList struct { } // NewRemoteList creates a new empty RemoteList -func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func(netip.Addr) bool) *RemoteList { +func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func([]netip.Addr, netip.Addr) bool) *RemoteList { r := &RemoteList{ vpnAddrs: make([]netip.Addr, len(vpnAddrs)), addrs: make([]netip.AddrPort, 0), @@ -336,21 +338,21 @@ func (r *RemoteList) CopyCache() *CacheMap { } // BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list -func (r *RemoteList) BlockRemote(bad netip.AddrPort) { - if !bad.IsValid() { - // relays can have nil udp Addrs +func (r *RemoteList) BlockRemote(bad ViaSender) { + if bad.IsRelayed { return } + r.Lock() defer r.Unlock() // Check if we already blocked this addr - if r.unlockedIsBad(bad) { + if r.unlockedIsBad(bad.UdpAddr) { return } // We copy here because we are taking something else's memory and we can't trust everything - r.badRemotes = append(r.badRemotes, bad) + r.badRemotes = append(r.badRemotes, bad.UdpAddr) // Mark the next interaction must recollect/dedupe r.shouldRebuild = true @@ -368,6 +370,15 @@ func (r *RemoteList) CopyBlockedRemotes() []netip.AddrPort { return c } +// RefreshFromHandshake locks and updates the RemoteList to account for data learned upon a completed handshake +func (r *RemoteList) RefreshFromHandshake(vpnAddrs []netip.Addr) { + r.Lock() + r.badRemotes = nil + r.vpnAddrs = make([]netip.Addr, len(vpnAddrs)) + copy(r.vpnAddrs, vpnAddrs) + r.Unlock() +} + // ResetBlockedRemotes locks and clears the blocked remotes list func (r *RemoteList) ResetBlockedRemotes() { r.Lock() @@ -577,7 +588,7 @@ func (r *RemoteList) unlockedCollect() { dnsAddrs := r.hr.GetAddrs() for _, addr := range dnsAddrs { - if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) { + if r.shouldAdd == nil || r.shouldAdd(r.vpnAddrs, addr.Addr()) { if !r.unlockedIsBad(addr) 
{ addrs = append(addrs, addr) } diff --git a/service/service_test.go b/service/service_test.go index f1c91a72..c6b87423 100644 --- a/service/service_test.go +++ b/service/service_test.go @@ -16,8 +16,8 @@ import ( "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/overlay" + "go.yaml.in/yaml/v3" "golang.org/x/sync/errgroup" - "gopkg.in/yaml.v3" ) type m = map[string]any diff --git a/test/tun.go b/test/tun.go index ca65805f..fb32782f 100644 --- a/test/tun.go +++ b/test/tun.go @@ -34,6 +34,10 @@ func (NoopTun) Write([]byte) (int, error) { return 0, nil } +func (NoopTun) SupportsMultiqueue() bool { + return false +} + func (NoopTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, errors.New("unsupported") } diff --git a/udp/conn.go b/udp/conn.go index 895b0df3..1ae585c2 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -19,6 +19,7 @@ type Conn interface { ListenOut(r EncReader) WriteTo(b []byte, addr netip.AddrPort) error ReloadConfig(c *config.C) + SupportsMultipleReaders() bool Close() error } @@ -33,6 +34,9 @@ func (NoopConn) LocalAddr() (netip.AddrPort, error) { func (NoopConn) ListenOut(_ EncReader) { return } +func (NoopConn) SupportsMultipleReaders() bool { + return false +} func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { return nil } diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index c0c6233c..91201194 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -98,9 +98,9 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error { return ErrInvalidIPv6RemoteForSocket } - var rsa unix.RawSockaddrInet6 - rsa.Family = unix.AF_INET6 - rsa.Addr = ap.Addr().As16() + var rsa unix.RawSockaddrInet4 + rsa.Family = unix.AF_INET + rsa.Addr = ap.Addr().As4() binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ap.Port()) sa = unsafe.Pointer(&rsa) addrLen = syscall.SizeofSockaddrInet4 @@ -184,6 +184,10 @@ func (u *StdConn) ListenOut(r EncReader) { } } +func (u *StdConn) 
SupportsMultipleReaders() bool { + return false +} + func (u *StdConn) Rebind() error { var err error if u.isV4 { diff --git a/udp/udp_generic.go b/udp/udp_generic.go index cb21e574..3cefc904 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -85,3 +85,7 @@ func (u *GenericConn) ListenOut(r EncReader) { r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) } } + +func (u *GenericConn) SupportsMultipleReaders() bool { + return false +} diff --git a/udp/udp_linux.go b/udp/udp_linux.go index ec0bf64b..e7759329 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -72,6 +72,10 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err } +func (u *StdConn) SupportsMultipleReaders() bool { + return true +} + func (u *StdConn) Rebind() error { return nil } diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go index 886e0244..1d602d01 100644 --- a/udp/udp_rio_windows.go +++ b/udp/udp_rio_windows.go @@ -315,6 +315,10 @@ func (u *RIOConn) LocalAddr() (netip.AddrPort, error) { } +func (u *RIOConn) SupportsMultipleReaders() bool { + return false +} + func (u *RIOConn) Rebind() error { return nil } diff --git a/udp/udp_tester.go b/udp/udp_tester.go index 8d5e6c14..5f0f7765 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -127,6 +127,10 @@ func (u *TesterConn) LocalAddr() (netip.AddrPort, error) { return u.Addr, nil } +func (u *TesterConn) SupportsMultipleReaders() bool { + return false +} + func (u *TesterConn) Rebind() error { return nil }