nebula/handshake/payload_test.go
2026-04-30 21:30:27 -05:00

361 lines
12 KiB
Go

package handshake
import (
"bytes"
"math"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/encoding/protowire"
)
// TestPayloadRoundTrip checks that bytes produced by MarshalPayload are
// accepted by UnmarshalPayload and that every field survives the trip.
// Fields left unset on the way in must come back as their zero values.
func TestPayloadRoundTrip(t *testing.T) {
	t.Run("all fields set", func(t *testing.T) {
		data := MarshalPayload(nil, Payload{
			Cert:           []byte("test-cert-bytes"),
			CertVersion:    2,
			InitiatorIndex: 12345,
			ResponderIndex: 67890,
			Time:           1234567890,
		})
		got, err := UnmarshalPayload(data)
		require.NoError(t, err)
		assert.Equal(t, []byte("test-cert-bytes"), got.Cert)
		assert.Equal(t, uint32(12345), got.InitiatorIndex)
		assert.Equal(t, uint32(67890), got.ResponderIndex)
		assert.Equal(t, uint64(1234567890), got.Time)
		assert.Equal(t, uint32(2), got.CertVersion)
	})
	t.Run("minimal fields", func(t *testing.T) {
		data := MarshalPayload(nil, Payload{InitiatorIndex: 1})
		got, err := UnmarshalPayload(data)
		require.NoError(t, err)
		assert.Equal(t, uint32(1), got.InitiatorIndex)
		assert.Equal(t, uint32(0), got.ResponderIndex)
		assert.Equal(t, uint64(0), got.Time)
		assert.Equal(t, uint32(0), got.CertVersion)
		assert.Nil(t, got.Cert)
	})
	t.Run("empty payload", func(t *testing.T) {
		data := MarshalPayload(nil, Payload{})
		got, err := UnmarshalPayload(data)
		require.NoError(t, err)
		// Every field must decode to its zero value, not just one of them.
		assert.Nil(t, got.Cert)
		assert.Equal(t, uint32(0), got.InitiatorIndex)
		assert.Equal(t, uint32(0), got.ResponderIndex)
		assert.Equal(t, uint64(0), got.Time)
		assert.Equal(t, uint32(0), got.CertVersion)
	})
	t.Run("maximum field values", func(t *testing.T) {
		// Values one past the 32-bit limit are rejected by the parser, so
		// the limit itself (and the full uint64 range for Time) must still
		// be accepted.
		data := MarshalPayload(nil, Payload{
			InitiatorIndex: math.MaxUint32,
			ResponderIndex: math.MaxUint32,
			Time:           math.MaxUint64,
			CertVersion:    math.MaxUint32,
		})
		got, err := UnmarshalPayload(data)
		require.NoError(t, err)
		assert.Equal(t, uint32(math.MaxUint32), got.InitiatorIndex)
		assert.Equal(t, uint32(math.MaxUint32), got.ResponderIndex)
		assert.Equal(t, uint64(math.MaxUint64), got.Time)
		assert.Equal(t, uint32(math.MaxUint32), got.CertVersion)
	})
	t.Run("large cert bytes", func(t *testing.T) {
		bigCert := make([]byte, 4096)
		for i := range bigCert {
			bigCert[i] = byte(i % 256)
		}
		data := MarshalPayload(nil, Payload{
			Cert:           bigCert,
			CertVersion:    2,
			InitiatorIndex: 999,
		})
		got, err := UnmarshalPayload(data)
		require.NoError(t, err)
		assert.Equal(t, bigCert, got.Cert)
		assert.Equal(t, uint32(2), got.CertVersion)
		assert.Equal(t, uint32(999), got.InitiatorIndex)
	})
	t.Run("append to existing buffer", func(t *testing.T) {
		// MarshalPayload appends to dst, so the prefix must be preserved and
		// the encoding must start right after it.
		prefix := []byte("prefix")
		data := MarshalPayload(prefix, Payload{InitiatorIndex: 42})
		assert.Equal(t, []byte("prefix"), data[:6])
		got, err := UnmarshalPayload(data[6:])
		require.NoError(t, err)
		assert.Equal(t, uint32(42), got.InitiatorIndex)
	})
}
// TestPayloadUnknownFields checks that UnmarshalPayload skips fields it does
// not recognise, in both the outer envelope and the nested details message,
// and that known fields on either side of an unknown one still decode.
func TestPayloadUnknownFields(t *testing.T) {
	t.Run("unknown field in outer message is skipped", func(t *testing.T) {
		// Marshal a normal payload then append an unknown field (field 99, varint).
		data := MarshalPayload(nil, Payload{InitiatorIndex: 42})
		data = protowire.AppendTag(data, 99, protowire.VarintType)
		data = protowire.AppendVarint(data, 12345)
		got, err := UnmarshalPayload(data)
		require.NoError(t, err)
		assert.Equal(t, uint32(42), got.InitiatorIndex)
	})
	t.Run("length-delimited unknown field in outer message is skipped", func(t *testing.T) {
		// Unknown fields are not always varints. A bytes-typed one must be
		// skipped by its declared length without disturbing known fields.
		data := MarshalPayload(nil, Payload{InitiatorIndex: 42})
		data = protowire.AppendTag(data, 99, protowire.BytesType)
		data = protowire.AppendBytes(data, []byte("future-field"))
		got, err := UnmarshalPayload(data)
		require.NoError(t, err)
		assert.Equal(t, uint32(42), got.InitiatorIndex)
	})
	t.Run("unknown field in details is skipped", func(t *testing.T) {
		// Build details with a known field, then an unknown one, then another
		// known field, so we prove parsing resumes after the skip.
		var details []byte
		details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
		details = protowire.AppendVarint(details, 77)
		// Unknown field 50, varint.
		details = protowire.AppendTag(details, 50, protowire.VarintType)
		details = protowire.AppendVarint(details, 9999)
		details = protowire.AppendTag(details, fieldResponderIndex, protowire.VarintType)
		details = protowire.AppendVarint(details, 88)

		got, err := UnmarshalPayload(wrapDetails(details))
		require.NoError(t, err)
		assert.Equal(t, uint32(77), got.InitiatorIndex)
		assert.Equal(t, uint32(88), got.ResponderIndex)
	})
	t.Run("reserved fields 6 and 7 are skipped", func(t *testing.T) {
		// Fields 6 and 7 are reserved in the proto definition.
		var details []byte
		details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
		details = protowire.AppendVarint(details, 100)
		details = protowire.AppendTag(details, 6, protowire.VarintType)
		details = protowire.AppendVarint(details, 1)
		details = protowire.AppendTag(details, 7, protowire.VarintType)
		details = protowire.AppendVarint(details, 2)

		got, err := UnmarshalPayload(wrapDetails(details))
		require.NoError(t, err)
		assert.Equal(t, uint32(100), got.InitiatorIndex)
	})
}
func TestPayloadBytesConsumed(t *testing.T) {
t.Run("all bytes consumed on valid input", func(t *testing.T) {
original := Payload{
Cert: []byte("cert"),
CertVersion: 2,
InitiatorIndex: 100,
ResponderIndex: 200,
Time: 999,
}
data := MarshalPayload(nil, original)
got, err := UnmarshalPayload(data)
require.NoError(t, err)
// Re-marshal and compare — proves we consumed and reproduced all fields
remarshaled := MarshalPayload(nil, got)
assert.Equal(t, data, remarshaled)
})
}
// wrapDetails returns a minimal outer NebulaHandshake message whose only
// content is field 1, length-delimited, carrying details verbatim. Tests use
// it so that hand-built detail bytes reach unmarshalPayloadDetails through
// UnmarshalPayload rather than failing in the outer parse.
func wrapDetails(details []byte) []byte {
	const outerDetailsField protowire.Number = 1
	tagged := protowire.AppendTag(nil, outerDetailsField, protowire.BytesType)
	return protowire.AppendBytes(tagged, details)
}
// TestPayloadUnmarshalErrors checks how UnmarshalPayload treats malformed or
// unusual input. Truncation anywhere, a wrong wire type on a known details
// field, and a varint that overflows a 32-bit field are rejected. Nil input
// and repeated singular fields (proto3 last-wins) are accepted.
func TestPayloadUnmarshalErrors(t *testing.T) {
	t.Run("nil input", func(t *testing.T) {
		got, err := UnmarshalPayload(nil)
		require.NoError(t, err)
		assert.Equal(t, uint32(0), got.InitiatorIndex)
	})
	t.Run("truncated outer tag", func(t *testing.T) {
		// 0x80 has the varint continuation bit set but nothing follows.
		_, err := UnmarshalPayload([]byte{0x80})
		assert.Error(t, err)
	})
	t.Run("truncated outer details field", func(t *testing.T) {
		// 0x0a is the tag for field 1 with bytes wire type. 0x64 declares a
		// 100-byte value, but only 5 bytes follow.
		_, err := UnmarshalPayload([]byte{0x0a, 0x64, 0x01, 0x02, 0x03, 0x04, 0x05})
		assert.Error(t, err)
	})
	t.Run("truncated outer unknown field", func(t *testing.T) {
		// Valid tag for unknown field 99 (varint), but no value follows.
		data := protowire.AppendTag(nil, 99, protowire.VarintType)
		_, err := UnmarshalPayload(data)
		assert.Error(t, err)
	})

	// Builders for a single encoded field inside the details message.
	varintField := func(num protowire.Number, v uint64) []byte {
		return protowire.AppendVarint(protowire.AppendTag(nil, num, protowire.VarintType), v)
	}
	bytesField := func(num protowire.Number, v []byte) []byte {
		return protowire.AppendBytes(protowire.AppendTag(nil, num, protowire.BytesType), v)
	}
	truncatedVarint := func(num protowire.Number) []byte {
		// 0x80 starts a multi-byte varint that never finishes.
		return append(protowire.AppendTag(nil, num, protowire.VarintType), 0x80)
	}
	wrongWireType := []byte{1, 2, 3}

	// Each details message below is malformed. Wrapping it in a valid outer
	// envelope means any error must come from details parsing.
	rejected := []struct {
		name    string
		details []byte
	}{
		{"truncated details tag", []byte{0x80}},
		// Cert field, bytes type, declared length 10 (0x0a) but only 2 bytes.
		{"truncated cert bytes", append(protowire.AppendTag(nil, fieldCert, protowire.BytesType), 0x0a, 0x01, 0x02)},
		{"truncated initiator index varint", truncatedVarint(fieldInitiatorIndex)},
		{"truncated responder index varint", truncatedVarint(fieldResponderIndex)},
		{"truncated time varint", truncatedVarint(fieldTime)},
		{"truncated cert version varint", truncatedVarint(fieldCertVersion)},
		{"truncated unknown field in details", truncatedVarint(50)},
		// Known fields encoded with the wrong wire type.
		{"cert with wrong wire type rejected", varintField(fieldCert, 42)},
		{"initiator index with wrong wire type rejected", bytesField(fieldInitiatorIndex, wrongWireType)},
		{"responder index with wrong wire type rejected", bytesField(fieldResponderIndex, wrongWireType)},
		{"time with wrong wire type rejected", bytesField(fieldTime, wrongWireType)},
		{"cert version with wrong wire type rejected", bytesField(fieldCertVersion, wrongWireType)},
		// Varints that do not fit the field's uint32 type.
		{"initiator index varint overflow rejected", varintField(fieldInitiatorIndex, math.MaxUint32+1)},
		{"responder index varint overflow rejected", varintField(fieldResponderIndex, math.MaxUint32+1)},
		{"cert version varint overflow rejected", varintField(fieldCertVersion, math.MaxUint32+1)},
	}
	for _, tc := range rejected {
		t.Run(tc.name, func(t *testing.T) {
			_, err := UnmarshalPayload(wrapDetails(tc.details))
			assert.Error(t, err)
		})
	}

	t.Run("repeated singular field follows proto3 last-wins", func(t *testing.T) {
		// Per proto3, multiple instances of a singular field are accepted and
		// the last value wins. We keep this behavior so that peers using
		// alternative encoders aren't rejected.
		details := append(varintField(fieldInitiatorIndex, 1), varintField(fieldInitiatorIndex, 42)...)
		got, err := UnmarshalPayload(wrapDetails(details))
		require.NoError(t, err)
		assert.Equal(t, uint32(42), got.InitiatorIndex)
	})
}
// FuzzPayload feeds arbitrary bytes through UnmarshalPayload to confirm it
// never panics. For any input that parses cleanly, it also checks that
// re-marshal + re-parse is a fix-point. Inputs come from an authenticated peer
// (post-noise-decrypt), so the threat model is "valid peer behaving
// arbitrarily," not "unauthenticated injection."
func FuzzPayload(f *testing.F) {
	// Known-good shapes produced by our own encoder.
	f.Add(MarshalPayload(nil, Payload{}))
	f.Add(MarshalPayload(nil, Payload{Cert: []byte{1, 2, 3}, CertVersion: 2}))
	f.Add(MarshalPayload(nil, Payload{InitiatorIndex: 42, Time: 1}))
	f.Add(MarshalPayload(nil, Payload{
		Cert:           []byte("seed-cert"),
		InitiatorIndex: 1,
		ResponderIndex: 2,
		Time:           3,
		CertVersion:    2,
	}))

	// Boundary values: the largest value each numeric field can hold.
	f.Add(MarshalPayload(nil, Payload{
		InitiatorIndex: math.MaxUint32,
		ResponderIndex: math.MaxUint32,
		Time:           math.MaxUint64,
		CertVersion:    math.MaxUint32,
	}))

	// Shapes our encoder never emits but the parser must accept. Seeding
	// them lets the fuzzer mutate from these paths instead of having to
	// discover them:
	//   - a trailing length-delimited unknown field in the outer message;
	//   - a repeated singular field inside details.
	withUnknown := MarshalPayload(nil, Payload{InitiatorIndex: 7})
	withUnknown = protowire.AppendTag(withUnknown, 99, protowire.BytesType)
	withUnknown = protowire.AppendBytes(withUnknown, []byte{0xde, 0xad})
	f.Add(withUnknown)

	var repeated []byte
	repeated = protowire.AppendTag(repeated, fieldInitiatorIndex, protowire.VarintType)
	repeated = protowire.AppendVarint(repeated, 1)
	repeated = protowire.AppendTag(repeated, fieldInitiatorIndex, protowire.VarintType)
	repeated = protowire.AppendVarint(repeated, 2)
	f.Add(wrapDetails(repeated))

	// Degenerate inputs.
	f.Add([]byte{})
	f.Add([]byte{0xff})

	f.Fuzz(func(t *testing.T, data []byte) {
		p1, err := UnmarshalPayload(data)
		if err != nil {
			return
		}
		// For any input that parses, re-marshaling and re-parsing must
		// yield an equivalent Payload. This catches dispatch bugs (e.g.
		// emitting a field on marshal that we don't accept on parse) and
		// any non-idempotent parsing behavior.
		b2 := MarshalPayload(nil, p1)
		p2, err := UnmarshalPayload(b2)
		if err != nil {
			t.Fatalf("re-parse of self-marshaled payload failed: %v\nintermediate: %x\n", err, b2)
		}
		if !payloadsEqual(p1, p2) {
			t.Fatalf("re-marshal not idempotent\nfirst: %+v\nsecond: %+v", p1, p2)
		}
	})
}
// payloadsEqual reports whether a and b carry the same field values. Cert is
// compared with bytes.Equal, so a nil Cert and an empty non-nil Cert count as
// equal.
func payloadsEqual(a, b Payload) bool {
	switch {
	case a.InitiatorIndex != b.InitiatorIndex,
		a.ResponderIndex != b.ResponderIndex,
		a.Time != b.Time,
		a.CertVersion != b.CertVersion:
		return false
	}
	return bytes.Equal(a.Cert, b.Cert)
}