mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-10 22:22:27 -07:00
361 lines
12 KiB
Go
package handshake
|
|
|
|
import (
|
|
"bytes"
|
|
"math"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"google.golang.org/protobuf/encoding/protowire"
|
|
)
|
|
|
|
func TestPayloadRoundTrip(t *testing.T) {
|
|
t.Run("all fields set", func(t *testing.T) {
|
|
data := MarshalPayload(nil, Payload{
|
|
Cert: []byte("test-cert-bytes"),
|
|
CertVersion: 2,
|
|
InitiatorIndex: 12345,
|
|
ResponderIndex: 67890,
|
|
Time: 1234567890,
|
|
})
|
|
|
|
got, err := UnmarshalPayload(data)
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, []byte("test-cert-bytes"), got.Cert)
|
|
assert.Equal(t, uint32(12345), got.InitiatorIndex)
|
|
assert.Equal(t, uint32(67890), got.ResponderIndex)
|
|
assert.Equal(t, uint64(1234567890), got.Time)
|
|
assert.Equal(t, uint32(2), got.CertVersion)
|
|
})
|
|
|
|
t.Run("minimal fields", func(t *testing.T) {
|
|
data := MarshalPayload(nil, Payload{InitiatorIndex: 1})
|
|
|
|
got, err := UnmarshalPayload(data)
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, uint32(1), got.InitiatorIndex)
|
|
assert.Equal(t, uint32(0), got.ResponderIndex)
|
|
assert.Equal(t, uint64(0), got.Time)
|
|
assert.Nil(t, got.Cert)
|
|
})
|
|
|
|
t.Run("empty payload", func(t *testing.T) {
|
|
data := MarshalPayload(nil, Payload{})
|
|
|
|
got, err := UnmarshalPayload(data)
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, uint32(0), got.InitiatorIndex)
|
|
})
|
|
|
|
t.Run("large cert bytes", func(t *testing.T) {
|
|
bigCert := make([]byte, 4096)
|
|
for i := range bigCert {
|
|
bigCert[i] = byte(i % 256)
|
|
}
|
|
|
|
data := MarshalPayload(nil, Payload{
|
|
Cert: bigCert,
|
|
CertVersion: 2,
|
|
InitiatorIndex: 999,
|
|
})
|
|
|
|
got, err := UnmarshalPayload(data)
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, bigCert, got.Cert)
|
|
assert.Equal(t, uint32(999), got.InitiatorIndex)
|
|
})
|
|
|
|
t.Run("append to existing buffer", func(t *testing.T) {
|
|
prefix := []byte("prefix")
|
|
data := MarshalPayload(prefix, Payload{InitiatorIndex: 42})
|
|
|
|
assert.Equal(t, []byte("prefix"), data[:6])
|
|
|
|
got, err := UnmarshalPayload(data[6:])
|
|
require.NoError(t, err)
|
|
assert.Equal(t, uint32(42), got.InitiatorIndex)
|
|
})
|
|
}
|
|
|
|
func TestPayloadUnknownFields(t *testing.T) {
|
|
t.Run("unknown field in outer message is skipped", func(t *testing.T) {
|
|
// Marshal a normal payload then append an unknown field (field 99, varint)
|
|
data := MarshalPayload(nil, Payload{InitiatorIndex: 42})
|
|
data = protowire.AppendTag(data, 99, protowire.VarintType)
|
|
data = protowire.AppendVarint(data, 12345)
|
|
|
|
got, err := UnmarshalPayload(data)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, uint32(42), got.InitiatorIndex)
|
|
})
|
|
|
|
t.Run("unknown field in details is skipped", func(t *testing.T) {
|
|
// Build details with a known field + unknown field
|
|
var details []byte
|
|
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
|
|
details = protowire.AppendVarint(details, 77)
|
|
// Unknown field 50, varint
|
|
details = protowire.AppendTag(details, 50, protowire.VarintType)
|
|
details = protowire.AppendVarint(details, 9999)
|
|
// Another known field after the unknown one
|
|
details = protowire.AppendTag(details, fieldResponderIndex, protowire.VarintType)
|
|
details = protowire.AppendVarint(details, 88)
|
|
|
|
// Wrap in outer message
|
|
var data []byte
|
|
data = protowire.AppendTag(data, 1, protowire.BytesType)
|
|
data = protowire.AppendBytes(data, details)
|
|
|
|
got, err := UnmarshalPayload(data)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, uint32(77), got.InitiatorIndex)
|
|
assert.Equal(t, uint32(88), got.ResponderIndex)
|
|
})
|
|
|
|
t.Run("reserved fields 6 and 7 are skipped", func(t *testing.T) {
|
|
// Fields 6 and 7 are reserved in the proto definition
|
|
var details []byte
|
|
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
|
|
details = protowire.AppendVarint(details, 100)
|
|
details = protowire.AppendTag(details, 6, protowire.VarintType)
|
|
details = protowire.AppendVarint(details, 1)
|
|
details = protowire.AppendTag(details, 7, protowire.VarintType)
|
|
details = protowire.AppendVarint(details, 2)
|
|
|
|
var data []byte
|
|
data = protowire.AppendTag(data, 1, protowire.BytesType)
|
|
data = protowire.AppendBytes(data, details)
|
|
|
|
got, err := UnmarshalPayload(data)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, uint32(100), got.InitiatorIndex)
|
|
})
|
|
}
|
|
|
|
func TestPayloadBytesConsumed(t *testing.T) {
|
|
t.Run("all bytes consumed on valid input", func(t *testing.T) {
|
|
original := Payload{
|
|
Cert: []byte("cert"),
|
|
CertVersion: 2,
|
|
InitiatorIndex: 100,
|
|
ResponderIndex: 200,
|
|
Time: 999,
|
|
}
|
|
data := MarshalPayload(nil, original)
|
|
|
|
got, err := UnmarshalPayload(data)
|
|
require.NoError(t, err)
|
|
|
|
// Re-marshal and compare — proves we consumed and reproduced all fields
|
|
remarshaled := MarshalPayload(nil, got)
|
|
assert.Equal(t, data, remarshaled)
|
|
})
|
|
}
|
|
|
|
// wrapDetails wraps raw detail bytes in the outer NebulaHandshake envelope
|
|
// so UnmarshalPayload can reach unmarshalPayloadDetails.
|
|
func wrapDetails(details []byte) []byte {
|
|
var out []byte
|
|
out = protowire.AppendTag(out, 1, protowire.BytesType)
|
|
out = protowire.AppendBytes(out, details)
|
|
return out
|
|
}
|
|
|
|
func TestPayloadUnmarshalErrors(t *testing.T) {
|
|
t.Run("nil input", func(t *testing.T) {
|
|
got, err := UnmarshalPayload(nil)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, uint32(0), got.InitiatorIndex)
|
|
})
|
|
|
|
t.Run("truncated outer tag", func(t *testing.T) {
|
|
_, err := UnmarshalPayload([]byte{0x80})
|
|
assert.Error(t, err)
|
|
})
|
|
|
|
t.Run("truncated outer details field", func(t *testing.T) {
|
|
_, err := UnmarshalPayload([]byte{0x0a, 0x64, 0x01, 0x02, 0x03, 0x04, 0x05})
|
|
assert.Error(t, err)
|
|
})
|
|
|
|
t.Run("truncated outer unknown field", func(t *testing.T) {
|
|
// Valid tag for unknown field 99 varint, but no value follows
|
|
var data []byte
|
|
data = protowire.AppendTag(data, 99, protowire.VarintType)
|
|
_, err := UnmarshalPayload(data)
|
|
assert.Error(t, err)
|
|
})
|
|
|
|
t.Run("truncated details tag", func(t *testing.T) {
|
|
_, err := UnmarshalPayload(wrapDetails([]byte{0x80}))
|
|
assert.Error(t, err)
|
|
})
|
|
|
|
t.Run("truncated cert bytes", func(t *testing.T) {
|
|
// Field 1 (cert), bytes type, length 10 but only 2 bytes
|
|
var details []byte
|
|
details = protowire.AppendTag(details, fieldCert, protowire.BytesType)
|
|
details = append(details, 0x0a, 0x01, 0x02) // length 10, only 2 bytes
|
|
_, err := UnmarshalPayload(wrapDetails(details))
|
|
assert.Error(t, err)
|
|
})
|
|
|
|
t.Run("truncated initiator index varint", func(t *testing.T) {
|
|
var details []byte
|
|
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
|
|
details = append(details, 0x80) // incomplete varint
|
|
_, err := UnmarshalPayload(wrapDetails(details))
|
|
assert.Error(t, err)
|
|
})
|
|
|
|
t.Run("truncated responder index varint", func(t *testing.T) {
|
|
var details []byte
|
|
details = protowire.AppendTag(details, fieldResponderIndex, protowire.VarintType)
|
|
details = append(details, 0x80)
|
|
_, err := UnmarshalPayload(wrapDetails(details))
|
|
assert.Error(t, err)
|
|
})
|
|
|
|
t.Run("truncated time varint", func(t *testing.T) {
|
|
var details []byte
|
|
details = protowire.AppendTag(details, fieldTime, protowire.VarintType)
|
|
details = append(details, 0x80)
|
|
_, err := UnmarshalPayload(wrapDetails(details))
|
|
assert.Error(t, err)
|
|
})
|
|
|
|
t.Run("truncated cert version varint", func(t *testing.T) {
|
|
var details []byte
|
|
details = protowire.AppendTag(details, fieldCertVersion, protowire.VarintType)
|
|
details = append(details, 0x80)
|
|
_, err := UnmarshalPayload(wrapDetails(details))
|
|
assert.Error(t, err)
|
|
})
|
|
|
|
t.Run("truncated unknown field in details", func(t *testing.T) {
|
|
var details []byte
|
|
details = protowire.AppendTag(details, 50, protowire.VarintType)
|
|
details = append(details, 0x80) // incomplete varint
|
|
_, err := UnmarshalPayload(wrapDetails(details))
|
|
assert.Error(t, err)
|
|
})
|
|
|
|
t.Run("cert with wrong wire type rejected", func(t *testing.T) {
|
|
// fieldCert as Varint instead of Bytes.
|
|
var details []byte
|
|
details = protowire.AppendTag(details, fieldCert, protowire.VarintType)
|
|
details = protowire.AppendVarint(details, 42)
|
|
_, err := UnmarshalPayload(wrapDetails(details))
|
|
assert.Error(t, err)
|
|
})
|
|
|
|
t.Run("initiator index with wrong wire type rejected", func(t *testing.T) {
|
|
// fieldInitiatorIndex as Bytes instead of Varint.
|
|
var details []byte
|
|
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.BytesType)
|
|
details = protowire.AppendBytes(details, []byte{1, 2, 3})
|
|
_, err := UnmarshalPayload(wrapDetails(details))
|
|
assert.Error(t, err)
|
|
})
|
|
|
|
t.Run("time with wrong wire type rejected", func(t *testing.T) {
|
|
var details []byte
|
|
details = protowire.AppendTag(details, fieldTime, protowire.BytesType)
|
|
details = protowire.AppendBytes(details, []byte{1, 2, 3})
|
|
_, err := UnmarshalPayload(wrapDetails(details))
|
|
assert.Error(t, err)
|
|
})
|
|
|
|
t.Run("cert version with wrong wire type rejected", func(t *testing.T) {
|
|
var details []byte
|
|
details = protowire.AppendTag(details, fieldCertVersion, protowire.BytesType)
|
|
details = protowire.AppendBytes(details, []byte{1, 2, 3})
|
|
_, err := UnmarshalPayload(wrapDetails(details))
|
|
assert.Error(t, err)
|
|
})
|
|
|
|
t.Run("repeated singular field follows proto3 last-wins", func(t *testing.T) {
|
|
// Per proto3, multiple instances of a singular field are accepted and
|
|
// the last value wins. We keep this behavior so that peers using
|
|
// alternative encoders aren't rejected.
|
|
var details []byte
|
|
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
|
|
details = protowire.AppendVarint(details, 1)
|
|
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
|
|
details = protowire.AppendVarint(details, 42)
|
|
got, err := UnmarshalPayload(wrapDetails(details))
|
|
require.NoError(t, err)
|
|
assert.Equal(t, uint32(42), got.InitiatorIndex)
|
|
})
|
|
|
|
t.Run("initiator index varint overflow rejected", func(t *testing.T) {
|
|
var details []byte
|
|
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
|
|
details = protowire.AppendVarint(details, math.MaxUint32+1)
|
|
_, err := UnmarshalPayload(wrapDetails(details))
|
|
assert.Error(t, err)
|
|
})
|
|
|
|
t.Run("cert version varint overflow rejected", func(t *testing.T) {
|
|
var details []byte
|
|
details = protowire.AppendTag(details, fieldCertVersion, protowire.VarintType)
|
|
details = protowire.AppendVarint(details, math.MaxUint32+1)
|
|
_, err := UnmarshalPayload(wrapDetails(details))
|
|
assert.Error(t, err)
|
|
})
|
|
|
|
}
|
|
|
|
// FuzzPayload feeds arbitrary bytes through UnmarshalPayload to confirm it
|
|
// never panics, and for any input that parses cleanly, that re-marshal +
|
|
// re-parse is a fix-point. Inputs come from an authenticated peer (post-
|
|
// noise-decrypt), so the threat model is "valid peer behaving arbitrarily,"
|
|
// not "unauthenticated injection."
|
|
func FuzzPayload(f *testing.F) {
|
|
// Seed corpus with a handful of known-good shapes.
|
|
f.Add(MarshalPayload(nil, Payload{}))
|
|
f.Add(MarshalPayload(nil, Payload{Cert: []byte{1, 2, 3}, CertVersion: 2}))
|
|
f.Add(MarshalPayload(nil, Payload{InitiatorIndex: 42, Time: 1}))
|
|
f.Add(MarshalPayload(nil, Payload{
|
|
Cert: []byte("seed-cert"),
|
|
InitiatorIndex: 1,
|
|
ResponderIndex: 2,
|
|
Time: 3,
|
|
CertVersion: 2,
|
|
}))
|
|
f.Add([]byte{})
|
|
f.Add([]byte{0xff})
|
|
|
|
f.Fuzz(func(t *testing.T, data []byte) {
|
|
p1, err := UnmarshalPayload(data)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
// For any input that parses, re-marshaling and re-parsing must
|
|
// yield an equivalent Payload. This catches dispatch bugs (e.g.
|
|
// emitting a field on marshal that we don't accept on parse) and
|
|
// any non-idempotent parsing behavior.
|
|
b2 := MarshalPayload(nil, p1)
|
|
p2, err := UnmarshalPayload(b2)
|
|
if err != nil {
|
|
t.Fatalf("re-parse of self-marshaled payload failed: %v\nintermediate: %x\n", err, b2)
|
|
}
|
|
if !payloadsEqual(p1, p2) {
|
|
t.Fatalf("re-marshal not idempotent\nfirst: %+v\nsecond: %+v", p1, p2)
|
|
}
|
|
})
|
|
}
|
|
|
|
func payloadsEqual(a, b Payload) bool {
|
|
return bytes.Equal(a.Cert, b.Cert) &&
|
|
a.InitiatorIndex == b.InitiatorIndex &&
|
|
a.ResponderIndex == b.ResponderIndex &&
|
|
a.Time == b.Time &&
|
|
a.CertVersion == b.CertVersion
|
|
}
|