nebula/bits.go
2026-05-06 10:22:26 -05:00

263 lines
7.3 KiB
Go

package nebula
import (
"context"
"fmt"
"log/slog"
"math"
mathbits "math/bits"
"github.com/rcrowley/go-metrics"
)
const bitsPerWord = 64
// Bits is a sliding-window anti-replay tracker. The window is stored as a
// circular bitmap packed into uint64 words (8x denser than a []bool), so a
// length-N window costs N/8 bytes. length must be a power of two.
type Bits struct {
length uint64
lengthMask uint64
current uint64
bits []uint64
lostCounter metrics.Counter
dupeCounter metrics.Counter
outOfWindowCounter metrics.Counter
}
func NewBits(length uint64) *Bits {
if length == 0 || length&(length-1) != 0 {
panic(fmt.Sprintf("Bits length must be a power of two, got %d", length))
}
nWords := length / bitsPerWord
if nWords == 0 {
nWords = 1
}
b := &Bits{
length: length,
lengthMask: length - 1,
bits: make([]uint64, nWords),
current: 0,
lostCounter: metrics.GetOrRegisterCounter("network.packets.lost", nil),
dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil),
outOfWindowCounter: metrics.GetOrRegisterCounter("network.packets.out_of_window", nil),
}
// There is no counter value 0, mark it to avoid counting a lost packet later.
b.bits[0] = 1
return b
}
func (b *Bits) get(i uint64) bool {
pos := i & b.lengthMask
//bit-shifting by 6 because i is a bit index, not a u64 index, and we need to find the u64 without bit in it
return b.bits[pos>>6]&(uint64(1)<<(pos&63)) != 0
}
func (b *Bits) set(i uint64) {
pos := i & b.lengthMask
b.bits[pos>>6] |= uint64(1) << (pos & 63)
}
// clearRange clears `count` bits starting at circular position `startPos`
// (already masked to [0, length)) and returns how many of them were set
// before the clear. count must be in [1, length].
func (b *Bits) clearRange(startPos, count uint64) uint64 {
wasSet := uint64(0)
if count >= b.length {
for _, w := range b.bits {
wasSet += uint64(mathbits.OnesCount64(w))
}
clear(b.bits)
return wasSet
}
pos := startPos
remaining := count
// handle the potential partial word before pos becomes u64 aligned
word := pos >> 6
bit := pos & 63
take := uint64(64) - bit
if take > remaining {
take = remaining
}
if take > b.length-pos {
take = b.length - pos
}
var mask uint64
if take == 64 {
mask = math.MaxUint64
} else {
mask = ((uint64(1) << take) - 1) << bit
}
wasSet += uint64(mathbits.OnesCount64(b.bits[word] & mask))
b.bits[word] &^= mask
remaining -= take
pos = (pos + take) & b.lengthMask
// Clear whole words, keeping track of the number of set bits
for remaining >= 64 {
word = pos >> 6
wasSet += uint64(mathbits.OnesCount64(b.bits[word]))
b.bits[word] = 0
remaining -= 64
pos = (pos + 64) & b.lengthMask
}
// Clear the remaining partial word
if remaining > 0 {
word = pos >> 6
mask = (uint64(1) << remaining) - 1
wasSet += uint64(mathbits.OnesCount64(b.bits[word] & mask))
b.bits[word] &^= mask
}
return wasSet
}
func (b *Bits) strictlyWithinWindow(i uint64) bool {
// Handle the case where the window hasn't slid yet. This avoids u64 underflow.
inWarmup := b.current < b.length
if i < b.length && inWarmup {
return true
}
// Next, if the packet is in-window, see if we've seen it before
if i > b.current-b.length {
return true
}
return false //not within window!
}
// Check returns true if i is within (or way out in front of) the window, and not a replay
func (b *Bits) Check(l *slog.Logger, i uint64) bool {
// If i is the next number, return true.
if i > b.current {
return true
}
if b.strictlyWithinWindow(i) {
return !b.get(i)
}
// Not within the window
if l.Enabled(context.Background(), slog.LevelDebug) {
l.Debug("rejected a packet (top)", "current", b.current, "incoming", i)
}
return false
}
// Update has three branches:
// - i == b.current+1: fast path; advance the cursor by one and lose-count
// the slot we just stomped (only past warmup; see the i > b.length guard
// below).
// - i > b.current+1: jump path; clear all slots between current and i
// (or up to a full window's worth, whichever is smaller) via clearRange,
// then mark i. Two arms here: a warmup arm that handles the very first
// window before the cursor has slid, and a steady-state arm that treats
// every cleared empty slot as a lost packet.
// - i <= b.current: in-window check for duplicates; out-of-window otherwise.
//
// NewBits seeds bits[0]=1 so counter 0 looks "received" — Update never
// clears that marker during warmup (clearRange skips position 0 when
// startPos=1), and once b.current >= b.length the marker is no longer
// consulted. The marker prevents a fictitious "lost" hit on the first real
// counter.
func (b *Bits) Update(l *slog.Logger, i uint64) bool {
// Fast path: i is the next expected counter. Split out so the function
// stays small and avoids paying for the slow paths' slog argument-build
// stack frame on every call. The bit read/test/write is inlined to
// touch the backing word once.
if i == b.current+1 {
pos := i & b.lengthMask
word := pos >> 6
mask := uint64(1) << (pos & 63)
w := b.bits[word]
if i > b.length && w&mask == 0 {
b.lostCounter.Inc(1)
}
b.bits[word] = w | mask
b.current = i
return true
}
return b.updateSlow(l, i)
}
// updateSlow handles jumps, in-window backfill, dupes, and out-of-window.
func (b *Bits) updateSlow(l *slog.Logger, i uint64) bool {
// If i is a jump, adjust the window, record lost, update current, and return true
if i > b.current {
end := i
if end > b.current+b.length {
end = b.current + b.length
}
count := end - b.current
startPos := (b.current + 1) & b.lengthMask
var lost int64
if b.current >= b.length {
// Steady state: every cleared slot is past warmup, so any unset
// bit we evict is a lost packet from the previous cycle.
wasSet := b.clearRange(startPos, count)
lost = int64(count) - int64(wasSet)
} else {
// Warmup (the very first window). Some cleared slots represent
// packets <= length where eviction is not "lost" in the usual
// sense. This branch is taken at most once per connection so we
// don't bother optimizing it.
for n := b.current + 1; n <= end; n++ {
if !b.get(n) && n > b.length {
lost++
}
}
b.clearRange(startPos, count)
}
// Anything past the new window can never be backfilled, so it's lost.
if i > b.current+b.length {
lost += int64(i - b.current - b.length)
}
b.lostCounter.Inc(lost)
b.set(i)
b.current = i
return true
}
// If i is within the current window but below the current counter, check to see if it's a duplicate
if b.strictlyWithinWindow(i) {
pos := i & b.lengthMask
word := pos >> 6
mask := uint64(1) << (pos & 63)
w := b.bits[word]
if b.current == i || w&mask != 0 {
if l.Enabled(context.Background(), slog.LevelDebug) {
l.Debug("Receive window",
"accepted", false,
"currentCounter", b.current,
"incomingCounter", i,
"reason", "duplicate",
)
}
b.dupeCounter.Inc(1)
return false
}
b.bits[word] = w | mask
return true
}
// In all other cases, fail and don't change current.
b.outOfWindowCounter.Inc(1)
if l.Enabled(context.Background(), slog.LevelDebug) {
l.Debug("Receive window",
"accepted", false,
"currentCounter", b.current,
"incomingCounter", i,
"reason", "nonsense",
)
}
return false
}