Cert interface (#1212)

This commit is contained in:
Nate Brown 2024-10-10 18:00:22 -05:00 committed by GitHub
parent 16eaae306a
commit 08ac65362e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
49 changed files with 2862 additions and 2833 deletions

46
pki.go
View file

@ -16,16 +16,17 @@ import (
type PKI struct {
cs atomic.Pointer[CertState]
caPool atomic.Pointer[cert.NebulaCAPool]
caPool atomic.Pointer[cert.CAPool]
l *logrus.Logger
}
type CertState struct {
Certificate *cert.NebulaCertificate
Certificate cert.Certificate
RawCertificate []byte
RawCertificateNoKey []byte
PublicKey []byte
PrivateKey []byte
pkcs11Backed bool
}
func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
@ -49,7 +50,7 @@ func (p *PKI) GetCertState() *CertState {
return p.cs.Load()
}
func (p *PKI) GetCAPool() *cert.NebulaCAPool {
func (p *PKI) GetCAPool() *cert.CAPool {
return p.caPool.Load()
}
@ -84,12 +85,12 @@ func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError {
// did IP in cert change? if so, don't set
currentCert := p.cs.Load().Certificate
oldIPs := currentCert.Details.Ips
newIPs := cs.Certificate.Details.Ips
oldIPs := currentCert.Networks()
newIPs := cs.Certificate.Networks()
if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() {
return util.NewContextualError(
"IP in new cert was different from old",
m{"new_ip": newIPs[0], "old_ip": oldIPs[0]},
"Networks in new cert was different from old",
m{"new_network": newIPs[0], "old_network": oldIPs[0]},
nil,
)
}
@ -115,29 +116,28 @@ func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError {
return nil
}
func newCertState(certificate *cert.NebulaCertificate, privateKey []byte) (*CertState, error) {
func newCertState(certificate cert.Certificate, pkcs11backed bool, privateKey []byte) (*CertState, error) {
// Marshal the certificate to ensure it is valid
rawCertificate, err := certificate.Marshal()
if err != nil {
return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err)
}
publicKey := certificate.Details.PublicKey
publicKey := certificate.PublicKey()
cs := &CertState{
RawCertificate: rawCertificate,
Certificate: certificate,
PrivateKey: privateKey,
PublicKey: publicKey,
pkcs11Backed: pkcs11backed,
}
cs.Certificate.Details.PublicKey = nil
rawCertNoKey, err := cs.Certificate.Marshal()
rawCertNoKey, err := cs.Certificate.MarshalForHandshakes()
if err != nil {
return nil, fmt.Errorf("error marshalling certificate no key: %s", err)
}
cs.RawCertificateNoKey = rawCertNoKey
// put public key back
cs.Certificate.Details.PublicKey = cs.PublicKey
return cs, nil
}
@ -146,7 +146,7 @@ func loadPrivateKey(privPathOrPEM string) (rawKey []byte, curve cert.Curve, isPk
if strings.Contains(privPathOrPEM, "-----BEGIN") {
pemPrivateKey = []byte(privPathOrPEM)
privPathOrPEM = "<inline>"
rawKey, _, curve, err = cert.UnmarshalPrivateKey(pemPrivateKey)
rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
if err != nil {
return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
}
@ -158,7 +158,7 @@ func loadPrivateKey(privPathOrPEM string) (rawKey []byte, curve cert.Curve, isPk
if err != nil {
return nil, curve, false, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err)
}
rawKey, _, curve, err = cert.UnmarshalPrivateKey(pemPrivateKey)
rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
if err != nil {
return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
}
@ -198,27 +198,27 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) {
}
}
nebulaCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCert)
nebulaCert, _, err := cert.UnmarshalCertificateFromPEM(rawCert)
if err != nil {
return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err)
}
nebulaCert.Pkcs11Backed = isPkcs11
if nebulaCert.Expired(time.Now()) {
return nil, fmt.Errorf("nebula certificate for this host is expired")
}
if len(nebulaCert.Details.Ips) == 0 {
return nil, fmt.Errorf("no IPs encoded in certificate")
if len(nebulaCert.Networks()) == 0 {
return nil, fmt.Errorf("no networks encoded in certificate")
}
if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil {
return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
}
return newCertState(nebulaCert, rawKey)
return newCertState(nebulaCert, isPkcs11, rawKey)
}
func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, error) {
func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
var rawCA []byte
var err error
@ -237,11 +237,11 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, er
}
}
caPool, err := cert.NewCAPoolFromBytes(rawCA)
caPool, err := cert.NewCAPoolFromPEM(rawCA)
if errors.Is(err, cert.ErrExpired) {
var expired int
for _, crt := range caPool.CAs {
if crt.Expired(time.Now()) {
if crt.Certificate.Expired(time.Now()) {
expired++
l.WithField("cert", crt).Warn("expired certificate present in CA pool")
}