nebula/handshake/patterns_test.go
2026-04-30 21:30:27 -05:00

63 lines
1.8 KiB
Go

package handshake
import (
"testing"
"github.com/flynn/noise"
"github.com/slackhq/nebula/header"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSubtypeInfo(t *testing.T) {
t.Run("IX", func(t *testing.T) {
info, err := subtypeInfoFor(header.HandshakeIXPSK0)
require.NoError(t, err)
assert.Equal(t, noise.HandshakeIX.Name, info.pattern.Name)
require.Len(t, info.msgs, 2)
// msg1: payload + cert
assert.True(t, info.msgs[0].expectsPayload)
assert.True(t, info.msgs[0].expectsCert)
// msg2: payload + cert
assert.True(t, info.msgs[1].expectsPayload)
assert.True(t, info.msgs[1].expectsCert)
})
t.Run("XX", func(t *testing.T) {
registerTestXXInfo(t)
info, err := subtypeInfoFor(header.HandshakeXXPSK0)
require.NoError(t, err)
assert.Equal(t, noise.HandshakeXX.Name, info.pattern.Name)
require.Len(t, info.msgs, 3)
// msg1: payload only
assert.True(t, info.msgs[0].expectsPayload)
assert.False(t, info.msgs[0].expectsCert)
// msg2: payload + cert
assert.True(t, info.msgs[1].expectsPayload)
assert.True(t, info.msgs[1].expectsCert)
// msg3: cert only
assert.False(t, info.msgs[2].expectsPayload)
assert.True(t, info.msgs[2].expectsCert)
})
t.Run("unknown subtype returns error", func(t *testing.T) {
_, err := subtypeInfoFor(99)
require.ErrorIs(t, err, ErrUnknownSubtype)
})
}
// registerTestXXInfo temporarily registers XX subtype info for testing.
func registerTestXXInfo(t *testing.T) {
t.Helper()
subtypeInfos[header.HandshakeXXPSK0] = subtypeInfo{
pattern: noise.HandshakeXX,
msgs: []msgFlags{
{expectsPayload: true, expectsCert: false},
{expectsPayload: true, expectsCert: true},
{expectsPayload: false, expectsCert: true},
},
}
t.Cleanup(func() {
delete(subtypeInfos, header.HandshakeXXPSK0)
})
}