Ensure pubkey coherency when rehydrating a handshake cert (#1566)

* Ensure pubkey coherency when rehydrating a handshake cert
* Include a check during handshakes after cert verification that the noise pubkey matches the cert pubkey.
This commit is contained in:
brad-defined 2026-01-09 09:52:03 -05:00 committed by GitHub
parent 3ec527e42c
commit 2f71d6b22d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 142 additions and 5 deletions

View file

@ -119,6 +119,7 @@ func (cc *CachedCertificate) String() string {
// Recombine will attempt to unmarshal a certificate received in a handshake.
// Handshakes save space by placing the peers public key in a different part of the packet, we have to
// reassemble the actual certificate structure with that in mind.
// Implementations MUST assert the public key is not in the raw certificate bytes if the passed in public key is not empty.
func Recombine(v Version, rawCertBytes, publicKey []byte, curve Curve) (Certificate, error) {
if publicKey == nil {
return nil, ErrNoPeerStaticKey

View file

@ -426,7 +426,7 @@ func unmarshalCertificateV1(b []byte, publicKey []byte) (*certificateV1, error)
unsafeNetworks: make([]netip.Prefix, len(rc.Details.Subnets)/2),
notBefore: time.Unix(rc.Details.NotBefore, 0),
notAfter: time.Unix(rc.Details.NotAfter, 0),
publicKey: make([]byte, len(rc.Details.PublicKey)),
publicKey: nil,
isCA: rc.Details.IsCA,
curve: rc.Details.Curve,
},
@ -437,12 +437,19 @@ func unmarshalCertificateV1(b []byte, publicKey []byte) (*certificateV1, error)
copy(nc.details.groups, rc.Details.Groups)
nc.details.issuer = hex.EncodeToString(rc.Details.Issuer)
// If a public key is passed in as an argument, the certificate pubkey must be empty
// and the passed-in pubkey copied into the cert.
if len(publicKey) > 0 {
nc.details.publicKey = publicKey
if len(rc.Details.PublicKey) != 0 {
return nil, ErrCertPubkeyPresent
}
nc.details.publicKey = make([]byte, len(publicKey))
copy(nc.details.publicKey, publicKey)
} else {
nc.details.publicKey = make([]byte, len(rc.Details.PublicKey))
copy(nc.details.publicKey, rc.Details.PublicKey)
}
copy(nc.details.publicKey, rc.Details.PublicKey)
var ip netip.Addr
for i, rawIp := range rc.Details.Ips {
if i%2 == 0 {

View file

@ -62,6 +62,62 @@ func TestCertificateV1_Marshal(t *testing.T) {
assert.Equal(t, nc.Groups(), nc2.Groups())
}
func TestCertificateV1_Unmarshal(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := []byte("1234567890abcedfghij1234567890ab")
invalidPubkey := []byte("00000000000000000000000000000000")
nc := certificateV1{
details: detailsV1{
name: "testing",
networks: []netip.Prefix{
mustParsePrefixUnmapped("10.1.1.1/24"),
mustParsePrefixUnmapped("10.1.1.2/16"),
},
unsafeNetworks: []netip.Prefix{
mustParsePrefixUnmapped("9.1.1.2/24"),
mustParsePrefixUnmapped("9.1.1.3/16"),
},
groups: []string{"test-group1", "test-group2", "test-group3"},
notBefore: before,
notAfter: after,
publicKey: pubKey,
isCA: false,
issuer: "1234567890abcedfghij1234567890ab",
},
signature: []byte("1234567890abcedfghij1234567890ab"),
}
// This certificate has a pubkey included
certWithPubkey, err := nc.Marshal()
require.NoError(t, err)
// This certificate is missing the pubkey section
certWithoutPubkey, err := nc.MarshalForHandshakes()
require.NoError(t, err)
// Cert has no pubkey and no pubkey passed in must fail to validate
isNil, err := unmarshalCertificateV1(certWithoutPubkey, nil)
require.Error(t, err)
// Cert has different pubkey than one passed in must fail
isNil, err = unmarshalCertificateV1(certWithPubkey, invalidPubkey)
require.Nil(t, isNil)
require.Error(t, err)
// Cert has pubkey and no pubkey argument works ok
_, err = unmarshalCertificateV1(certWithPubkey, nil)
require.NoError(t, err)
// Cert has no pubkey and valid, correctly signed pubkey passed in
nc2, err := unmarshalCertificateV1(certWithoutPubkey, pubKey)
require.NoError(t, err)
assert.Equal(t, pubKey, nc2.PublicKey())
}
func TestCertificateV1_PublicKeyPem(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second)

View file

@ -592,7 +592,13 @@ func unmarshalCertificateV2(b []byte, publicKey []byte, curve Curve) (*certifica
// Maybe grab the public key
var rawPublicKey cryptobyte.String
if len(publicKey) > 0 {
rawPublicKey = publicKey
// If a public key is passed in, then the handshake certificate must
// not have a public key present
if input.PeekASN1Tag(TagCertPublicKey) {
return nil, ErrCertPubkeyPresent
}
rawPublicKey = make(cryptobyte.String, len(publicKey))
copy(rawPublicKey, publicKey)
} else if !input.ReadOptionalASN1(&rawPublicKey, nil, TagCertPublicKey) {
return nil, ErrBadFormat
}

View file

@ -76,6 +76,58 @@ func TestCertificateV2_Marshal(t *testing.T) {
assert.Equal(t, nc.Groups(), nc2.Groups())
}
func TestCertificateV2_Unmarshal(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second)
after := time.Now().Add(time.Second * 60).Round(time.Second)
pubKey := []byte("1234567890abcedfghij1234567890ab")
nc := certificateV2{
details: detailsV2{
name: "testing",
networks: []netip.Prefix{
mustParsePrefixUnmapped("10.1.1.2/16"),
mustParsePrefixUnmapped("10.1.1.1/24"),
},
unsafeNetworks: []netip.Prefix{
mustParsePrefixUnmapped("9.1.1.3/16"),
mustParsePrefixUnmapped("9.1.1.2/24"),
},
groups: []string{"test-group1", "test-group2", "test-group3"},
notBefore: before,
notAfter: after,
isCA: false,
issuer: "1234567890abcdef1234567890abcdef",
},
signature: []byte("1234567890abcdef1234567890abcdef"),
publicKey: pubKey,
}
db, err := nc.details.Marshal()
require.NoError(t, err)
nc.rawDetails = db
certWithPubkey, err := nc.Marshal()
require.NoError(t, err)
//t.Log("Cert size:", len(b))
certWithoutPubkey, err := nc.MarshalForHandshakes()
require.NoError(t, err)
// Cert must not have a pubkey if one is passed in as an argument
_, err = unmarshalCertificateV2(certWithPubkey, pubKey, Curve_CURVE25519)
require.ErrorIs(t, err, ErrCertPubkeyPresent)
// Certs must have pubkeys
_, err = unmarshalCertificateV2(certWithoutPubkey, nil, Curve_CURVE25519)
require.ErrorIs(t, err, ErrBadFormat)
// Ensure proper unmarshal if a pubkey is passed in
nc2, err := unmarshalCertificateV2(certWithoutPubkey, pubKey, Curve_CURVE25519)
require.NoError(t, err)
assert.Equal(t, nc.PublicKey(), nc2.PublicKey())
}
func TestCertificateV2_PublicKeyPem(t *testing.T) {
t.Parallel()
before := time.Now().Add(time.Second * -60).Round(time.Second)

View file

@ -21,6 +21,7 @@ var (
ErrPrivateKeyEncrypted = errors.New("private key must be decrypted")
ErrCaNotFound = errors.New("could not find ca for the certificate")
ErrUnknownVersion = errors.New("certificate version unrecognized")
ErrCertPubkeyPresent = errors.New("certificate has unexpected pubkey present")
ErrInvalidPEMBlock = errors.New("input did not contain a valid PEM encoded block")
ErrInvalidPEMCertificateBanner = errors.New("bytes did not contain a proper certificate banner")

View file

@ -1,6 +1,7 @@
package nebula
import (
"bytes"
"net/netip"
"time"
@ -166,6 +167,13 @@ func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H)
return
}
if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) {
f.l.WithField("from", via).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("cert", remoteCert).Info("public key mismatch between certificate and handshake")
return
}
if remoteCert.Certificate.Version() != ci.myCert.Version() {
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version())
@ -535,6 +543,12 @@ func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packe
e.Info("Invalid certificate from host")
return true
}
if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) {
f.l.WithField("from", via).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithField("cert", remoteCert).Info("public key mismatch between certificate and handshake")
return true
}
if len(remoteCert.Certificate.Networks()) == 0 {
f.l.WithError(err).WithField("from", via).