mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-10 22:22:27 -07:00
263 lines
7.3 KiB
Go
263 lines
7.3 KiB
Go
package nebula
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log/slog"
|
|
"math"
|
|
mathbits "math/bits"
|
|
|
|
"github.com/rcrowley/go-metrics"
|
|
)
|
|
|
|
// bitsPerWord is the number of window slots stored in each uint64 element of
// Bits.bits; NewBits divides the window length by it to size the bitmap.
const bitsPerWord = 64
|
|
|
|
// Bits is a sliding-window anti-replay tracker. The window is stored as a
|
|
// circular bitmap packed into uint64 words (8x denser than a []bool), so a
|
|
// length-N window costs N/8 bytes. length must be a power of two.
|
|
type Bits struct {
|
|
length uint64
|
|
lengthMask uint64
|
|
current uint64
|
|
bits []uint64
|
|
lostCounter metrics.Counter
|
|
dupeCounter metrics.Counter
|
|
outOfWindowCounter metrics.Counter
|
|
}
|
|
|
|
func NewBits(length uint64) *Bits {
|
|
if length == 0 || length&(length-1) != 0 {
|
|
panic(fmt.Sprintf("Bits length must be a power of two, got %d", length))
|
|
}
|
|
|
|
nWords := length / bitsPerWord
|
|
if nWords == 0 {
|
|
nWords = 1
|
|
}
|
|
b := &Bits{
|
|
length: length,
|
|
lengthMask: length - 1,
|
|
bits: make([]uint64, nWords),
|
|
current: 0,
|
|
lostCounter: metrics.GetOrRegisterCounter("network.packets.lost", nil),
|
|
dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil),
|
|
outOfWindowCounter: metrics.GetOrRegisterCounter("network.packets.out_of_window", nil),
|
|
}
|
|
|
|
// There is no counter value 0, mark it to avoid counting a lost packet later.
|
|
b.bits[0] = 1
|
|
return b
|
|
}
|
|
|
|
func (b *Bits) get(i uint64) bool {
|
|
pos := i & b.lengthMask
|
|
//bit-shifting by 6 because i is a bit index, not a u64 index, and we need to find the u64 without bit in it
|
|
return b.bits[pos>>6]&(uint64(1)<<(pos&63)) != 0
|
|
}
|
|
|
|
func (b *Bits) set(i uint64) {
|
|
pos := i & b.lengthMask
|
|
b.bits[pos>>6] |= uint64(1) << (pos & 63)
|
|
}
|
|
|
|
// clearRange clears `count` bits starting at circular position `startPos`
|
|
// (already masked to [0, length)) and returns how many of them were set
|
|
// before the clear. count must be in [1, length].
|
|
func (b *Bits) clearRange(startPos, count uint64) uint64 {
|
|
wasSet := uint64(0)
|
|
if count >= b.length {
|
|
for _, w := range b.bits {
|
|
wasSet += uint64(mathbits.OnesCount64(w))
|
|
}
|
|
clear(b.bits)
|
|
return wasSet
|
|
}
|
|
|
|
pos := startPos
|
|
remaining := count
|
|
|
|
// handle the potential partial word before pos becomes u64 aligned
|
|
word := pos >> 6
|
|
bit := pos & 63
|
|
take := uint64(64) - bit
|
|
if take > remaining {
|
|
take = remaining
|
|
}
|
|
if take > b.length-pos {
|
|
take = b.length - pos
|
|
}
|
|
var mask uint64
|
|
if take == 64 {
|
|
mask = math.MaxUint64
|
|
} else {
|
|
mask = ((uint64(1) << take) - 1) << bit
|
|
}
|
|
wasSet += uint64(mathbits.OnesCount64(b.bits[word] & mask))
|
|
b.bits[word] &^= mask
|
|
remaining -= take
|
|
pos = (pos + take) & b.lengthMask
|
|
|
|
// Clear whole words, keeping track of the number of set bits
|
|
for remaining >= 64 {
|
|
word = pos >> 6
|
|
wasSet += uint64(mathbits.OnesCount64(b.bits[word]))
|
|
b.bits[word] = 0
|
|
remaining -= 64
|
|
pos = (pos + 64) & b.lengthMask
|
|
}
|
|
|
|
// Clear the remaining partial word
|
|
if remaining > 0 {
|
|
word = pos >> 6
|
|
mask = (uint64(1) << remaining) - 1
|
|
wasSet += uint64(mathbits.OnesCount64(b.bits[word] & mask))
|
|
b.bits[word] &^= mask
|
|
}
|
|
|
|
return wasSet
|
|
}
|
|
|
|
func (b *Bits) strictlyWithinWindow(i uint64) bool {
|
|
// Handle the case where the window hasn't slid yet. This avoids u64 underflow.
|
|
inWarmup := b.current < b.length
|
|
if i < b.length && inWarmup {
|
|
return true
|
|
}
|
|
|
|
// Next, if the packet is in-window, see if we've seen it before
|
|
if i > b.current-b.length {
|
|
return true
|
|
}
|
|
return false //not within window!
|
|
}
|
|
|
|
// Check returns true if i is within (or way out in front of) the window, and not a replay
|
|
func (b *Bits) Check(l *slog.Logger, i uint64) bool {
|
|
// If i is the next number, return true.
|
|
if i > b.current {
|
|
return true
|
|
}
|
|
|
|
if b.strictlyWithinWindow(i) {
|
|
return !b.get(i)
|
|
}
|
|
|
|
// Not within the window
|
|
if l.Enabled(context.Background(), slog.LevelDebug) {
|
|
l.Debug("rejected a packet (top)", "current", b.current, "incoming", i)
|
|
}
|
|
return false
|
|
}
|
|
|
|
// Update has three branches:
|
|
// - i == b.current+1: fast path; advance the cursor by one and lose-count
|
|
// the slot we just stomped (only past warmup; see the i > b.length guard
|
|
// below).
|
|
// - i > b.current+1: jump path; clear all slots between current and i
|
|
// (or up to a full window's worth, whichever is smaller) via clearRange,
|
|
// then mark i. Two arms here: a warmup arm that handles the very first
|
|
// window before the cursor has slid, and a steady-state arm that treats
|
|
// every cleared empty slot as a lost packet.
|
|
// - i <= b.current: in-window check for duplicates; out-of-window otherwise.
|
|
//
|
|
// NewBits seeds bits[0]=1 so counter 0 looks "received" — Update never
|
|
// clears that marker during warmup (clearRange skips position 0 when
|
|
// startPos=1), and once b.current >= b.length the marker is no longer
|
|
// consulted. The marker prevents a fictitious "lost" hit on the first real
|
|
// counter.
|
|
func (b *Bits) Update(l *slog.Logger, i uint64) bool {
|
|
// Fast path: i is the next expected counter. Split out so the function
|
|
// stays small and avoids paying for the slow paths' slog argument-build
|
|
// stack frame on every call. The bit read/test/write is inlined to
|
|
// touch the backing word once.
|
|
if i == b.current+1 {
|
|
pos := i & b.lengthMask
|
|
word := pos >> 6
|
|
mask := uint64(1) << (pos & 63)
|
|
w := b.bits[word]
|
|
if i > b.length && w&mask == 0 {
|
|
b.lostCounter.Inc(1)
|
|
}
|
|
b.bits[word] = w | mask
|
|
b.current = i
|
|
return true
|
|
}
|
|
return b.updateSlow(l, i)
|
|
}
|
|
|
|
// updateSlow handles jumps, in-window backfill, dupes, and out-of-window.
|
|
func (b *Bits) updateSlow(l *slog.Logger, i uint64) bool {
|
|
// If i is a jump, adjust the window, record lost, update current, and return true
|
|
if i > b.current {
|
|
end := i
|
|
if end > b.current+b.length {
|
|
end = b.current + b.length
|
|
}
|
|
count := end - b.current
|
|
startPos := (b.current + 1) & b.lengthMask
|
|
|
|
var lost int64
|
|
if b.current >= b.length {
|
|
// Steady state: every cleared slot is past warmup, so any unset
|
|
// bit we evict is a lost packet from the previous cycle.
|
|
wasSet := b.clearRange(startPos, count)
|
|
lost = int64(count) - int64(wasSet)
|
|
} else {
|
|
// Warmup (the very first window). Some cleared slots represent
|
|
// packets <= length where eviction is not "lost" in the usual
|
|
// sense. This branch is taken at most once per connection so we
|
|
// don't bother optimizing it.
|
|
for n := b.current + 1; n <= end; n++ {
|
|
if !b.get(n) && n > b.length {
|
|
lost++
|
|
}
|
|
}
|
|
b.clearRange(startPos, count)
|
|
}
|
|
|
|
// Anything past the new window can never be backfilled, so it's lost.
|
|
if i > b.current+b.length {
|
|
lost += int64(i - b.current - b.length)
|
|
}
|
|
b.lostCounter.Inc(lost)
|
|
|
|
b.set(i)
|
|
b.current = i
|
|
return true
|
|
}
|
|
|
|
// If i is within the current window but below the current counter, check to see if it's a duplicate
|
|
if b.strictlyWithinWindow(i) {
|
|
pos := i & b.lengthMask
|
|
word := pos >> 6
|
|
mask := uint64(1) << (pos & 63)
|
|
w := b.bits[word]
|
|
if b.current == i || w&mask != 0 {
|
|
if l.Enabled(context.Background(), slog.LevelDebug) {
|
|
l.Debug("Receive window",
|
|
"accepted", false,
|
|
"currentCounter", b.current,
|
|
"incomingCounter", i,
|
|
"reason", "duplicate",
|
|
)
|
|
}
|
|
b.dupeCounter.Inc(1)
|
|
return false
|
|
}
|
|
|
|
b.bits[word] = w | mask
|
|
return true
|
|
}
|
|
|
|
// In all other cases, fail and don't change current.
|
|
b.outOfWindowCounter.Inc(1)
|
|
if l.Enabled(context.Background(), slog.LevelDebug) {
|
|
l.Debug("Receive window",
|
|
"accepted", false,
|
|
"currentCounter", b.current,
|
|
"incomingCounter", i,
|
|
"reason", "nonsense",
|
|
)
|
|
}
|
|
return false
|
|
}
|