diff --git a/cert/ca_pool.go b/cert/ca_pool.go index 2bf480f2..5439b06a 100644 --- a/cert/ca_pool.go +++ b/cert/ca_pool.go @@ -1,11 +1,13 @@ package cert import ( + "bytes" + "encoding/pem" "errors" "fmt" + "io" "net/netip" "slices" - "strings" "time" ) @@ -29,21 +31,55 @@ func NewCAPool() *CAPool { // If the pool contains any expired certificates, an ErrExpired will be // returned along with the pool. The caller must handle any such errors. func NewCAPoolFromPEM(caPEMs []byte) (*CAPool, error) { + return NewCAPoolFromPEMReader(bytes.NewReader(caPEMs)) +} + +// NewCAPoolFromPEMReader will create a new CA pool from the provided reader. +// The reader must contain a PEM-encoded set of nebula certificates. +func NewCAPoolFromPEMReader(r io.Reader) (*CAPool, error) { pool := NewCAPool() - var err error + buf := make([]byte, 0, 64*1024) + tmp := make([]byte, 32*1024) var expired bool + for { - caPEMs, err = pool.AddCAFromPEM(caPEMs) - if errors.Is(err, ErrExpired) { - expired = true - err = nil + n, err := r.Read(tmp) + if n > 0 { + buf = append(buf, tmp[:n]...) + + for { + var block *pem.Block + block, buf = pem.Decode(buf) + if block == nil { + break + } + + c, err := unmarshalCertificateBlock(block) + if err != nil { + return nil, err + } + + err = pool.AddCA(c) + if errors.Is(err, ErrExpired) { + expired = true + continue + } + if err != nil { + return nil, err + } + } + } + + if errors.Is(err, io.EOF) { + break } if err != nil { return nil, err } - if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" { - break - } + } + + if len(bytes.TrimSpace(buf)) > 0 { + return nil, ErrInvalidPEMBlock } if expired { diff --git a/cert/pem.go b/cert/pem.go index 8942c23a..9a312769 100644 --- a/cert/pem.go +++ b/cert/pem.go @@ -37,19 +37,7 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) { return nil, r, ErrInvalidPEMBlock } - var c Certificate - var err error - - switch p.Type { - // Implementations must validate the resulting certificate contains valid information - case CertificateBanner: - c, err = unmarshalCertificateV1(p.Bytes, nil) - case CertificateV2Banner: - c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519) - default: - return nil, r, ErrInvalidPEMCertificateBanner - } - + c, err := unmarshalCertificateBlock(p) if err != nil { return nil, r, err } @@ -58,6 +46,20 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) { } +// unmarshalCertificateBlock decodes a single PEM block into a certificate. +// It expects a Nebula certificate banner and returns ErrInvalidPEMCertificateBanner otherwise. +func unmarshalCertificateBlock(block *pem.Block) (Certificate, error) { + switch block.Type { + // Implementations must validate the resulting certificate contains valid information + case CertificateBanner: + return unmarshalCertificateV1(block.Bytes, nil) + case CertificateV2Banner: + return unmarshalCertificateV2(block.Bytes, nil, Curve_CURVE25519) + default: + return nil, ErrInvalidPEMCertificateBanner + } +} + func marshalCertPublicKeyToPEM(c Certificate) []byte { if c.IsCA() { return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey()) diff --git a/pki.go b/pki.go index 19869d58..0639fd3d 100644 --- a/pki.go +++ b/pki.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "net" "net/netip" "os" @@ -487,25 +488,25 @@ func loadCertificate(b []byte) (cert.Certificate, []byte, error) { } func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) { - var rawCA []byte - var err error - caPathOrPEM := c.GetString("pki.ca", "") if caPathOrPEM == "" { return nil, errors.New("no pki.ca path or PEM data provided") } - if strings.Contains(caPathOrPEM, "-----BEGIN") { - rawCA = []byte(caPathOrPEM) + var caReader io.ReadCloser + var err error + if strings.Contains(caPathOrPEM, "-----BEGIN") { + caReader = io.NopCloser(strings.NewReader(caPathOrPEM)) } else { - rawCA, err = os.ReadFile(caPathOrPEM) + caReader, err = os.Open(caPathOrPEM) if err != nil { return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err) } } + defer caReader.Close() - caPool, err := cert.NewCAPoolFromPEM(rawCA) + caPool, err := cert.NewCAPoolFromPEMReader(caReader) if errors.Is(err, cert.ErrExpired) { var expired int for _, crt := range caPool.CAs { diff --git a/pki_hup_benchmark_test.go b/pki_hup_benchmark_test.go new file mode 100644 index 00000000..3a201070 --- /dev/null +++ b/pki_hup_benchmark_test.go @@ -0,0 +1,120 @@ +package nebula + +import ( + "bytes" + "fmt" + "net/netip" + "os" + "path/filepath" + "runtime" + "testing" + "time" + + "github.com/slackhq/nebula/cert" + cert_test "github.com/slackhq/nebula/cert_test" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/test" + "github.com/stretchr/testify/require" +) + +func BenchmarkReloadConfigWithCAs(b *testing.B) { + prevProcs := runtime.GOMAXPROCS(1) + b.Cleanup(func() { runtime.GOMAXPROCS(prevProcs) }) + + for _, size := range []int{100, 250, 500, 1000, 5000} { + b.Run(fmt.Sprintf("%dCAs", size), func(b *testing.B) { + l := test.NewLogger() + dir := b.TempDir() + + ca, caKey, caBundle := buildCABundle(b, size) + caPath, certPath, keyPath := writePKIFiles(b, dir, ca, caKey, caBundle) + + configBody := fmt.Sprintf(`pki: + ca: %s + cert: %s + key: %s +`, caPath, certPath, keyPath) + + configPath := filepath.Join(dir, "config.yml") + require.NoError(b, os.WriteFile(configPath, []byte(configBody), 0o600)) + + c := config.NewC(l) + require.NoError(b, c.Load(dir)) + + _, err := NewPKIFromConfig(l, c) + require.NoError(b, err) + + b.ReportAllocs() + b.ResetTimer() + + for b.Loop() { + c.ReloadConfig() + } + }) + } +} + +func buildCABundle(b *testing.B, count int) (cert.Certificate, []byte, []byte) { + b.Helper() + require.GreaterOrEqual(b, count, 1) + + before := time.Now().Add(-24 * time.Hour) + after := time.Now().Add(24 * time.Hour) + + ca, _, caKey, pem := cert_test.NewTestCaCert( + cert.Version2, + cert.Curve_CURVE25519, + before, + after, + nil, + nil, + nil, + ) + + buf := bytes.NewBuffer(pem) + + for i := 1; i < count; i++ { + _, _, _, extraPEM := cert_test.NewTestCaCert( + cert.Version2, + cert.Curve_CURVE25519, + time.Now(), + time.Now().Add(time.Hour), + nil, + nil, + nil, + ) + + buf.Write(extraPEM) + } + + return ca, caKey, buf.Bytes() +} + +func writePKIFiles(b *testing.B, dir string, ca cert.Certificate, caKey []byte, caBundle []byte) (string, string, string) { + b.Helper() + + networks := []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")} + + _, _, keyPEM, certPEM := cert_test.NewTestCert( + cert.Version2, + cert.Curve_CURVE25519, + ca, + caKey, + "reload-benchmark", + time.Now(), + time.Now().Add(time.Hour), + networks, + nil, + nil, + ) + + caPath := filepath.Join(dir, "ca.pem") + certPath := filepath.Join(dir, "cert.pem") + keyPath := filepath.Join(dir, "key.pem") + + require.NoError(b, os.WriteFile(caPath, caBundle, 0o600)) + require.NoError(b, os.WriteFile(certPath, certPEM, 0o600)) + require.NoError(b, os.WriteFile(keyPath, keyPEM, 0o600)) + + return caPath, certPath, keyPath +}