This commit is contained in:
maggie44 2026-01-20 16:54:10 +00:00 committed by GitHub
commit 805313d41c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 188 additions and 29 deletions

View file

@ -1,11 +1,13 @@
package cert
import (
"bytes"
"encoding/pem"
"errors"
"fmt"
"io"
"net/netip"
"slices"
"strings"
"time"
)
@ -29,21 +31,55 @@ func NewCAPool() *CAPool {
// If the pool contains any expired certificates, an ErrExpired will be
// returned along with the pool. The caller must handle any such errors.
func NewCAPoolFromPEM(caPEMs []byte) (*CAPool, error) {
return NewCAPoolFromPEMReader(bytes.NewReader(caPEMs))
}
// NewCAPoolFromPEMReader will create a new CA pool from the provided reader.
// The reader must contain a PEM-encoded set of nebula certificates.
func NewCAPoolFromPEMReader(r io.Reader) (*CAPool, error) {
pool := NewCAPool()
var err error
buf := make([]byte, 0, 64*1024)
tmp := make([]byte, 32*1024)
var expired bool
for {
caPEMs, err = pool.AddCAFromPEM(caPEMs)
if errors.Is(err, ErrExpired) {
expired = true
err = nil
n, err := r.Read(tmp)
if n > 0 {
buf = append(buf, tmp[:n]...)
for {
var block *pem.Block
block, buf = pem.Decode(buf)
if block == nil {
break
}
c, err := unmarshalCertificateBlock(block)
if err != nil {
return nil, err
}
err = pool.AddCA(c)
if errors.Is(err, ErrExpired) {
expired = true
continue
}
if err != nil {
return nil, err
}
}
}
if errors.Is(err, io.EOF) {
break
}
if err != nil {
return nil, err
}
if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" {
break
}
}
if len(bytes.TrimSpace(buf)) > 0 {
return nil, ErrInvalidPEMBlock
}
if expired {

View file

@ -37,19 +37,7 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
return nil, r, ErrInvalidPEMBlock
}
var c Certificate
var err error
switch p.Type {
// Implementations must validate the resulting certificate contains valid information
case CertificateBanner:
c, err = unmarshalCertificateV1(p.Bytes, nil)
case CertificateV2Banner:
c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519)
default:
return nil, r, ErrInvalidPEMCertificateBanner
}
c, err := unmarshalCertificateBlock(p)
if err != nil {
return nil, r, err
}
@ -58,6 +46,20 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
}
// unmarshalCertificateBlock decodes a single PEM block into a certificate.
// It expects a Nebula certificate banner and returns ErrInvalidPEMCertificateBanner otherwise.
func unmarshalCertificateBlock(block *pem.Block) (Certificate, error) {
switch block.Type {
// Implementations must validate the resulting certificate contains valid information
case CertificateBanner:
return unmarshalCertificateV1(block.Bytes, nil)
case CertificateV2Banner:
return unmarshalCertificateV2(block.Bytes, nil, Curve_CURVE25519)
default:
return nil, ErrInvalidPEMCertificateBanner
}
}
func marshalCertPublicKeyToPEM(c Certificate) []byte {
if c.IsCA() {
return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey())

15
pki.go
View file

@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/netip"
"os"
@ -487,25 +488,25 @@ func loadCertificate(b []byte) (cert.Certificate, []byte, error) {
}
func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
var rawCA []byte
var err error
caPathOrPEM := c.GetString("pki.ca", "")
if caPathOrPEM == "" {
return nil, errors.New("no pki.ca path or PEM data provided")
}
if strings.Contains(caPathOrPEM, "-----BEGIN") {
rawCA = []byte(caPathOrPEM)
var caReader io.ReadCloser
var err error
if strings.Contains(caPathOrPEM, "-----BEGIN") {
caReader = io.NopCloser(strings.NewReader(caPathOrPEM))
} else {
rawCA, err = os.ReadFile(caPathOrPEM)
caReader, err = os.Open(caPathOrPEM)
if err != nil {
return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err)
}
}
defer caReader.Close()
caPool, err := cert.NewCAPoolFromPEM(rawCA)
caPool, err := cert.NewCAPoolFromPEMReader(caReader)
if errors.Is(err, cert.ErrExpired) {
var expired int
for _, crt := range caPool.CAs {

120
pki_hup_benchmark_test.go Normal file
View file

@ -0,0 +1,120 @@
package nebula
import (
"bytes"
"fmt"
"net/netip"
"os"
"path/filepath"
"runtime"
"testing"
"time"
"github.com/slackhq/nebula/cert"
cert_test "github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/require"
)
func BenchmarkReloadConfigWithCAs(b *testing.B) {
prevProcs := runtime.GOMAXPROCS(1)
b.Cleanup(func() { runtime.GOMAXPROCS(prevProcs) })
for _, size := range []int{100, 250, 500, 1000, 5000} {
b.Run(fmt.Sprintf("%dCAs", size), func(b *testing.B) {
l := test.NewLogger()
dir := b.TempDir()
ca, caKey, caBundle := buildCABundle(b, size)
caPath, certPath, keyPath := writePKIFiles(b, dir, ca, caKey, caBundle)
configBody := fmt.Sprintf(`pki:
ca: %s
cert: %s
key: %s
`, caPath, certPath, keyPath)
configPath := filepath.Join(dir, "config.yml")
require.NoError(b, os.WriteFile(configPath, []byte(configBody), 0o600))
c := config.NewC(l)
require.NoError(b, c.Load(dir))
_, err := NewPKIFromConfig(l, c)
require.NoError(b, err)
b.ReportAllocs()
b.ResetTimer()
for b.Loop() {
c.ReloadConfig()
}
})
}
}
func buildCABundle(b *testing.B, count int) (cert.Certificate, []byte, []byte) {
b.Helper()
require.GreaterOrEqual(b, count, 1)
before := time.Now().Add(-24 * time.Hour)
after := time.Now().Add(24 * time.Hour)
ca, _, caKey, pem := cert_test.NewTestCaCert(
cert.Version2,
cert.Curve_CURVE25519,
before,
after,
nil,
nil,
nil,
)
buf := bytes.NewBuffer(pem)
for i := 1; i < count; i++ {
_, _, _, extraPEM := cert_test.NewTestCaCert(
cert.Version2,
cert.Curve_CURVE25519,
time.Now(),
time.Now().Add(time.Hour),
nil,
nil,
nil,
)
buf.Write(extraPEM)
}
return ca, caKey, buf.Bytes()
}
func writePKIFiles(b *testing.B, dir string, ca cert.Certificate, caKey []byte, caBundle []byte) (string, string, string) {
b.Helper()
networks := []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}
_, _, keyPEM, certPEM := cert_test.NewTestCert(
cert.Version2,
cert.Curve_CURVE25519,
ca,
caKey,
"reload-benchmark",
time.Now(),
time.Now().Add(time.Hour),
networks,
nil,
nil,
)
caPath := filepath.Join(dir, "ca.pem")
certPath := filepath.Join(dir, "cert.pem")
keyPath := filepath.Join(dir, "key.pem")
require.NoError(b, os.WriteFile(caPath, caBundle, 0o600))
require.NoError(b, os.WriteFile(certPath, certPEM, 0o600))
require.NoError(b, os.WriteFile(keyPath, keyPEM, 0o600))
return caPath, certPath, keyPath
}