mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-10 22:22:27 -07:00
Some checks are pending
gofmt / Run gofmt (push) Waiting to run
smoke-extra / Run extra smoke tests (push) Waiting to run
smoke / Run multi node smoke test (push) Waiting to run
Build and test / Build all and test on ubuntu-linux (push) Waiting to run
Build and test / Build and test on linux with boringcrypto (push) Waiting to run
Build and test / Build and test on linux with pkcs11 (push) Waiting to run
Build and test / Build and test on macos-latest (push) Waiting to run
Build and test / Build and test on windows-latest (push) Waiting to run
446 lines
13 KiB
Go
446 lines
13 KiB
Go
package handshake
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"slices"
|
|
"time"
|
|
|
|
"github.com/flynn/noise"
|
|
"github.com/slackhq/nebula/cert"
|
|
"github.com/slackhq/nebula/header"
|
|
)
|
|
|
|
// IndexAllocator is called by the Machine to allocate a local index for the
// handshake. It is called at most once, when the first outgoing message that
// carries a payload is built. The returned value is stored as
// Result.LocalIndex and placed in the InitiatorIndex (when initiating) or
// ResponderIndex (when responding) field of every outgoing payload.
//
// Implementations MUST NOT return 0. Zero is reserved as a sentinel meaning
// "no index assigned": the initiator's first message goes out with a header
// remote index of 0 because no peer index is known yet, and zero-valued index
// fields are part of the payload-presence check. A local index of 0 would
// therefore collide with the "unassigned" value once a peer uses it as the
// remote index in its headers.
//
// A non-nil error aborts building the outgoing message; the Machine wraps it
// with ErrIndexAllocation.
type IndexAllocator func() (uint32, error)
|
|
|
|
// CertVerifier is called by the Machine after reconstructing the peer's
// certificate from the handshake. The verifier performs all validation
// (CA trust, expiry, policy checks, allow lists).
//
// The certificate it receives has been recombined from the handshake payload
// and the peer's noise static key. On success the returned
// *cert.CachedCertificate becomes Result.RemoteCert. Any non-nil error is fatal
// to the Machine (Failed() becomes true) and is surfaced wrapped as
// "verify cert: ...".
type CertVerifier func(cert.Certificate) (*cert.CachedCertificate, error)
|
|
|
|
// Result contains the results of a successful handshake.
// Returned by ProcessPacket when the handshake is complete.
type Result struct {
	// EKey is the cipher state for traffic this side sends.
	EKey *noise.CipherState

	// DKey is the cipher state for traffic received from the peer.
	DKey *noise.CipherState

	Cipher noise.CipherFunc // identifies which post-handshake CipherState the data plane should wrap EKey/DKey in

	// MyCert is the local certificate attached to the most recent outgoing
	// handshake message that carried a certificate.
	MyCert cert.Certificate

	// RemoteCert is the peer certificate as returned by the CertVerifier.
	RemoteCert *cert.CachedCertificate

	// RemoteIndex is the peer's index: the payload's ResponderIndex when we
	// initiated, its InitiatorIndex otherwise. It is written as the remote
	// index in the header of every outgoing handshake message.
	RemoteIndex uint32

	// LocalIndex is the index obtained from the IndexAllocator.
	LocalIndex uint32

	// HandshakeTime is the Time value from the peer's payload (this package's
	// sender fills it with time.Now().UnixNano()).
	HandshakeTime uint64

	MessageIndex uint64 // number of messages exchanged during the handshake

	// Initiator reports whether this side started the handshake.
	Initiator bool
}
|
|
|
|
// Machine drives a Noise handshake through N messages. It handles Noise
// protocol operations, certificate reconstruction, and payload encoding.
// Certificate validation is delegated to the caller via CertVerifier.
//
// A Machine is not safe for concurrent use. The caller must ensure that
// Initiate and ProcessPacket are not called concurrently.
//
// Error contract: when ProcessPacket or Initiate returns an error, callers
// must check Failed() to decide what to do next. If Failed() is false the
// underlying noise state was not advanced (the packet was rejected before
// ReadMessage took effect, or the rejection is non-fatal like a stale
// retransmit) and the Machine can accept another packet. If Failed() is
// true the Machine is unrecoverable and the caller must abandon it.
type Machine struct {
	// hs is the noise handshake state built from the credential in NewMachine.
	hs *noise.HandshakeState

	// getCred looks up the local credential for a certificate version.
	getCred GetCredentialFunc

	// allocIndex is invoked lazily, at most once, to obtain Result.LocalIndex.
	allocIndex IndexAllocator

	// verifier validates the peer certificate recombined from the handshake.
	verifier CertVerifier

	// result accumulates handshake outputs; it is returned once complete.
	result *Result

	// msgs holds per-message content flags for this subtype, indexed by the
	// noise message index.
	msgs []msgFlags

	// myVersion is the certificate version sent to the peer. validateCert may
	// switch it to the peer's version when a local credential for it exists.
	myVersion cert.Version

	// subtype is written into outgoing headers; incoming packets must match it.
	subtype header.MessageSubType

	// indexAllocated records that allocIndex has already succeeded.
	indexAllocated bool

	// remoteCertSet records that the peer certificate passed the verifier.
	remoteCertSet bool

	// payloadSet records that the peer's index/time payload was received.
	payloadSet bool

	// failed is the unrecoverable-state flag reported by Failed().
	failed bool
}
|
|
|
|
// NewMachine creates a handshake state machine. The subtype determines both
|
|
// the noise pattern and the per-message content layout. The credential for
|
|
// `version` is fetched via getCred and used to seed the noise.HandshakeState.
|
|
// IndexAllocator is called lazily when the first outgoing payload is built.
|
|
func NewMachine(
|
|
version cert.Version,
|
|
getCred GetCredentialFunc,
|
|
verifier CertVerifier,
|
|
allocIndex IndexAllocator,
|
|
initiator bool,
|
|
subtype header.MessageSubType,
|
|
) (*Machine, error) {
|
|
info, err := subtypeInfoFor(subtype)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
cred := getCred(version)
|
|
if cred == nil {
|
|
return nil, fmt.Errorf("%w: %v", ErrNoCredential, version)
|
|
}
|
|
|
|
hs, err := cred.buildHandshakeState(initiator, info.pattern)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("build noise state: %w", err)
|
|
}
|
|
|
|
return &Machine{
|
|
hs: hs,
|
|
subtype: subtype,
|
|
msgs: info.msgs,
|
|
getCred: getCred,
|
|
allocIndex: allocIndex,
|
|
verifier: verifier,
|
|
myVersion: version,
|
|
result: &Result{
|
|
Initiator: initiator,
|
|
Cipher: cred.cipherSuite,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
// Failed returns true if the Machine is in an unrecoverable state.
|
|
func (m *Machine) Failed() bool {
|
|
return m.failed
|
|
}
|
|
|
|
// Subtype returns the handshake subtype this Machine was built for.
|
|
func (m *Machine) Subtype() header.MessageSubType {
|
|
return m.subtype
|
|
}
|
|
|
|
// MessageIndex returns the noise handshake message index, which equals the
|
|
// wire counter of the most recently sent or received message.
|
|
func (m *Machine) MessageIndex() int {
|
|
return m.hs.MessageIndex()
|
|
}
|
|
|
|
// requireComplete checks that both a peer cert and payload have been received.
|
|
// Marks the machine as failed if not.
|
|
func (m *Machine) requireComplete() error {
|
|
if !m.payloadSet || !m.remoteCertSet {
|
|
m.failed = true
|
|
return ErrIncompleteHandshake
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// myMsgFlags returns the flags for the current outgoing message.
|
|
func (m *Machine) myMsgFlags() msgFlags {
|
|
idx := m.hs.MessageIndex()
|
|
if idx < len(m.msgs) {
|
|
return m.msgs[idx]
|
|
}
|
|
return msgFlags{}
|
|
}
|
|
|
|
// peerMsgFlags returns the flags for the message we just read.
|
|
func (m *Machine) peerMsgFlags() msgFlags {
|
|
idx := m.hs.MessageIndex() - 1
|
|
if idx >= 0 && idx < len(m.msgs) {
|
|
return m.msgs[idx]
|
|
}
|
|
return msgFlags{}
|
|
}
|
|
|
|
// Initiate produces the first handshake message. Only valid for initiators,
|
|
// and must be called exactly once before ProcessPacket.
|
|
//
|
|
// out is a destination buffer the message is appended to and returned. Pass
|
|
// nil to allocate fresh, or pass a re-used buffer sliced to length 0 (e.g.
|
|
// buf[:0]) with sufficient capacity to avoid allocation.
|
|
//
|
|
// An error return may not indicate a fatal condition, check Failed() to
|
|
// determine if the Machine can still be used.
|
|
func (m *Machine) Initiate(out []byte) ([]byte, error) {
|
|
if m.failed {
|
|
return nil, ErrMachineFailed
|
|
}
|
|
if !m.result.Initiator {
|
|
m.failed = true
|
|
return nil, ErrInitiateOnResponder
|
|
}
|
|
if m.hs.MessageIndex() != 0 {
|
|
m.failed = true
|
|
return nil, ErrInitiateAlreadyCalled
|
|
}
|
|
|
|
// At MessageIndex=0 with RemoteIndex still zero, buildResponse produces
|
|
// header counter 1 and remote index 0, which is what the initial message needs.
|
|
out, _, _, err := m.buildResponse(out)
|
|
if err != nil {
|
|
m.failed = true
|
|
return nil, err
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
// ProcessPacket handles an incoming handshake message. It advances the Noise
|
|
// state, validates the peer certificate via the verifier, and optionally
|
|
// produces a response.
|
|
//
|
|
// out is a destination buffer the response is appended to and returned. Pass
|
|
// nil to allocate fresh, or pass a re-used buffer sliced to length 0 (e.g.
|
|
// buf[:0]) with sufficient capacity to avoid allocation. The returned slice
|
|
// is nil when no outgoing message is produced (handshake complete on this
|
|
// side, or final message of a multi-message pattern).
|
|
//
|
|
// Returns a non-nil Result when the handshake is complete.
|
|
// An error return may not indicate a fatal condition, check Failed() to
|
|
// determine if the Machine can still be used.
|
|
func (m *Machine) ProcessPacket(out, packet []byte) ([]byte, *Result, error) {
|
|
if m.failed {
|
|
return nil, nil, ErrMachineFailed
|
|
}
|
|
if len(packet) < header.Len {
|
|
return nil, nil, ErrPacketTooShort
|
|
}
|
|
// Reject packets whose subtype doesn't match the one this Machine was
|
|
// built for. A pending handshake that suddenly receives a different
|
|
// subtype on its index is either a stray packet that matched by chance
|
|
// or a peer protocol violation; drop it without failing the Machine so
|
|
// the legitimate retransmit can still complete.
|
|
if header.MessageSubType(packet[1]) != m.subtype {
|
|
return nil, nil, ErrSubtypeMismatch
|
|
}
|
|
if m.result.Initiator && m.hs.MessageIndex() == 0 {
|
|
m.failed = true
|
|
return nil, nil, ErrInitiateNotCalled
|
|
}
|
|
|
|
// The (eKey, dKey) ordering here is correct for IX, where the initiator
|
|
// completes the handshake by reading the responder's stage-2 message.
|
|
// noise returns (cs1, cs2) where cs1 is the initiator->responder cipher.
|
|
// For 3-message patterns where a responder finishes by reading the final
|
|
// message, this ordering would be wrong; revisit when XX/pqIX lands.
|
|
msg, eKey, dKey, err := m.hs.ReadMessage(nil, packet[header.Len:])
|
|
if err != nil {
|
|
// Noise ReadMessage failed. The noise library checkpoints and rolls back
|
|
// on failure, so the Machine is still alive. The caller can retry with
|
|
// a different packet.
|
|
return nil, nil, fmt.Errorf("noise ReadMessage: %w", err)
|
|
}
|
|
|
|
// From here on, noise state has advanced. Any error is fatal.
|
|
flags := m.peerMsgFlags()
|
|
|
|
if err := m.processPayload(msg, flags); err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
// If ReadMessage derived keys, the handshake is complete. Noise should
|
|
// always produce both keys together; asymmetry is a protocol invariant
|
|
// violation.
|
|
if eKey != nil || dKey != nil {
|
|
if eKey == nil || dKey == nil {
|
|
m.failed = true
|
|
return nil, nil, ErrAsymmetricCipherKeys
|
|
}
|
|
if err := m.requireComplete(); err != nil {
|
|
return nil, nil, err
|
|
}
|
|
return nil, m.completed(eKey, dKey), nil
|
|
}
|
|
|
|
// ReadMessage didn't complete, produce the next outgoing message
|
|
out, dk, ek, err := m.buildResponse(out)
|
|
if err != nil {
|
|
m.failed = true
|
|
return nil, nil, err
|
|
}
|
|
|
|
if ek != nil || dk != nil {
|
|
if ek == nil || dk == nil {
|
|
m.failed = true
|
|
return nil, nil, ErrAsymmetricCipherKeys
|
|
}
|
|
if err := m.requireComplete(); err != nil {
|
|
return nil, nil, err
|
|
}
|
|
return out, m.completed(ek, dk), nil
|
|
}
|
|
|
|
return out, nil, nil
|
|
}
|
|
|
|
func (m *Machine) completed(eKey, dKey *noise.CipherState) *Result {
|
|
m.result.EKey = eKey
|
|
m.result.DKey = dKey
|
|
m.result.MessageIndex = uint64(m.hs.MessageIndex())
|
|
return m.result
|
|
}
|
|
|
|
func (m *Machine) processPayload(msg []byte, flags msgFlags) error {
|
|
if len(msg) == 0 {
|
|
if flags.expectsPayload || flags.expectsCert {
|
|
m.failed = true
|
|
return ErrMissingContent
|
|
}
|
|
return nil
|
|
}
|
|
|
|
payload, err := UnmarshalPayload(msg)
|
|
if err != nil {
|
|
m.failed = true
|
|
return fmt.Errorf("unmarshal handshake: %w", err)
|
|
}
|
|
|
|
// Assert the payload contains exactly what we expect
|
|
hasPayloadData := payload.InitiatorIndex != 0 || payload.ResponderIndex != 0 || payload.Time != 0
|
|
if hasPayloadData != flags.expectsPayload {
|
|
m.failed = true
|
|
return ErrUnexpectedContent
|
|
}
|
|
|
|
hasCertData := len(payload.Cert) > 0
|
|
if hasCertData != flags.expectsCert {
|
|
m.failed = true
|
|
return ErrUnexpectedContent
|
|
}
|
|
|
|
// Process payload
|
|
if flags.expectsPayload {
|
|
if m.result.Initiator {
|
|
m.result.RemoteIndex = payload.ResponderIndex
|
|
} else {
|
|
m.result.RemoteIndex = payload.InitiatorIndex
|
|
}
|
|
m.result.HandshakeTime = payload.Time
|
|
m.payloadSet = true
|
|
}
|
|
|
|
// Process certificate
|
|
if flags.expectsCert {
|
|
if err := m.validateCert(payload); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *Machine) validateCert(payload Payload) error {
|
|
cred := m.getCred(m.myVersion)
|
|
if cred == nil {
|
|
m.failed = true
|
|
return fmt.Errorf("%w: %v", ErrNoCredential, m.myVersion)
|
|
}
|
|
rc, err := cert.Recombine(
|
|
cert.Version(payload.CertVersion),
|
|
payload.Cert,
|
|
m.hs.PeerStatic(),
|
|
cred.Cert.Curve(),
|
|
)
|
|
if err != nil {
|
|
m.failed = true
|
|
return fmt.Errorf("recombine cert: %w", err)
|
|
}
|
|
|
|
if !bytes.Equal(rc.PublicKey(), m.hs.PeerStatic()) {
|
|
m.failed = true
|
|
return ErrPublicKeyMismatch
|
|
}
|
|
|
|
// Version negotiation, if the peer sent a different version and we have it, switch
|
|
if rc.Version() != m.myVersion {
|
|
if m.getCred(rc.Version()) != nil {
|
|
m.myVersion = rc.Version()
|
|
}
|
|
}
|
|
|
|
verified, err := m.verifier(rc)
|
|
if err != nil {
|
|
m.failed = true
|
|
return fmt.Errorf("verify cert: %w", err)
|
|
}
|
|
|
|
m.result.RemoteCert = verified
|
|
m.remoteCertSet = true
|
|
return nil
|
|
}
|
|
|
|
func (m *Machine) marshalOutgoing(flags msgFlags) ([]byte, error) {
|
|
if !flags.expectsPayload && !flags.expectsCert {
|
|
return nil, nil
|
|
}
|
|
|
|
var p Payload
|
|
if flags.expectsPayload {
|
|
if !m.indexAllocated {
|
|
index, err := m.allocIndex()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: %w", ErrIndexAllocation, err)
|
|
}
|
|
m.result.LocalIndex = index
|
|
m.indexAllocated = true
|
|
}
|
|
|
|
if m.result.Initiator {
|
|
p.InitiatorIndex = m.result.LocalIndex
|
|
} else {
|
|
p.ResponderIndex = m.result.LocalIndex
|
|
p.InitiatorIndex = m.result.RemoteIndex
|
|
}
|
|
p.Time = uint64(time.Now().UnixNano())
|
|
}
|
|
if flags.expectsCert {
|
|
cred := m.getCred(m.myVersion)
|
|
if cred == nil {
|
|
return nil, fmt.Errorf("%w: %v", ErrNoCredential, m.myVersion)
|
|
}
|
|
p.Cert = cred.Bytes
|
|
p.CertVersion = uint32(cred.Cert.Version())
|
|
m.result.MyCert = cred.Cert
|
|
}
|
|
|
|
return MarshalPayload(nil, p), nil
|
|
}
|
|
|
|
func (m *Machine) buildResponse(out []byte) ([]byte, *noise.CipherState, *noise.CipherState, error) {
|
|
flags := m.myMsgFlags()
|
|
hsBytes, err := m.marshalOutgoing(flags)
|
|
if err != nil {
|
|
return nil, nil, nil, err
|
|
}
|
|
|
|
// Extend out by header.Len to make room for the header. slices.Grow is a
|
|
// no-op when the cap is already sufficient (the zero-copy case where the
|
|
// caller passed a pre-sized buffer). header.Encode overwrites the new
|
|
// bytes, so they don't need to be zeroed.
|
|
start := len(out)
|
|
out = slices.Grow(out, header.Len)[:start+header.Len]
|
|
header.Encode(
|
|
out[start:],
|
|
header.Version, header.Handshake, m.subtype,
|
|
m.result.RemoteIndex,
|
|
uint64(m.hs.MessageIndex()+1),
|
|
)
|
|
|
|
// noise.WriteMessage appends the encrypted handshake message to out,
|
|
// reusing capacity when present.
|
|
//
|
|
// The (dKey, eKey) ordering here is correct for IX, where the responder
|
|
// completes the handshake by writing the stage-2 message. noise returns
|
|
// (cs1, cs2) where cs1 is the initiator->responder cipher (which is the
|
|
// responder's decrypt key). For 3-message patterns where an initiator
|
|
// finishes by writing the final message, this ordering would be wrong;
|
|
// revisit when XX/pqIX lands.
|
|
out, dKey, eKey, err := m.hs.WriteMessage(out, hsBytes)
|
|
if err != nil {
|
|
return nil, nil, nil, fmt.Errorf("noise WriteMessage: %w", err)
|
|
}
|
|
|
|
return out, dKey, eKey, nil
|
|
}
|