diff --git a/firewall_test.go b/firewall_test.go index 4df6eadd..dc863319 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -1042,7 +1042,7 @@ func TestNewFirewallFromConfig(t *testing.T) { l := test.NewLogger() // Test a bad rule definition c := &dummyCert{} - cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil) + cs, err := newCertState(l, cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil) require.NoError(t, err) conf := config.NewC(l) diff --git a/pki.go b/pki.go index 19869d58..5744e5af 100644 --- a/pki.go +++ b/pki.go @@ -91,7 +91,7 @@ func (p *PKI) reload(c *config.C, initial bool) error { } func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError { - newState, err := newCertStateFromConfig(c) + newState, err := newCertStateFromConfig(c, p.l) if err != nil { return util.NewContextualError("Could not load client cert", nil, err) } @@ -260,7 +260,7 @@ func (cs *CertState) MarshalJSON() ([]byte, error) { return json.Marshal(msg) } -func newCertStateFromConfig(c *config.C) (*CertState, error) { +func newCertStateFromConfig(c *config.C, l *logrus.Logger) (*CertState, error) { var err error privPathOrPEM := c.GetString("pki.key", "") @@ -344,10 +344,33 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) { return nil, fmt.Errorf("unknown pki.initiating_version: %v", rawInitiatingVersion) } - return newCertState(initiatingVersion, v1, v2, isPkcs11, curve, rawKey) + return newCertState(l, initiatingVersion, v1, v2, isPkcs11, curve, rawKey) } -func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) { +func compareUnsafeNetworksAcrossCertVersions(v1, v2 cert.Certificate) error { + if v1 == nil || v2 == nil { + return nil //can't be a problem if we don't have one of the kinds of cert + } + + v4UnsafeNets := 0 + for _, n := range v2.UnsafeNetworks() { + if n.Addr().Is6() { + continue // V1 certs can't have IPv6 unsafe networks + } else { + v4UnsafeNets++ + } + if !slices.Contains(v1.UnsafeNetworks(), n) { + return errors.New("UnsafeNetworks mismatch") + } + } + if len(v1.UnsafeNetworks()) != v4UnsafeNets { + return errors.New("UnsafeNetworks mismatch") + } + + return nil +} + +func newCertState(l *logrus.Logger, dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) { cs := CertState{ privateKey: privateKey, pkcs11Backed: pkcs11backed, @@ -370,6 +393,12 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p } cs.initiatingVersion = dv + + warn := compareUnsafeNetworksAcrossCertVersions(v1, v2) + if warn != nil { + l.WithFields(m{"UnsafeNetworksV1": v1.UnsafeNetworks(), "UnsafeNetworksV2": v2.UnsafeNetworks()}). + Warning("the IPv4 UnsafeNetworks assigned in the V1 certificate do not match the ones in V2") + } } if v1 != nil {