Switch most everything to netip in prep for ipv6 in the overlay (#1173)

This commit is contained in:
Nate Brown 2024-07-31 10:18:56 -05:00 committed by GitHub
parent 00458302ca
commit e264a0ff88
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
79 changed files with 1900 additions and 2682 deletions

View file

@ -6,23 +6,23 @@ import (
"errors"
"fmt"
"hash/fnv"
"net"
"net/netip"
"reflect"
"strconv"
"strings"
"sync"
"time"
"github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
)
type FirewallInterface interface {
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error
}
type conn struct {
@ -52,8 +52,8 @@ type Firewall struct {
DefaultTimeout time.Duration //linux: 600s
// Used to ensure we don't emit local packets for ips we don't own
localIps *cidr.Tree4[struct{}]
assignedCIDR *net.IPNet
localIps *bart.Table[struct{}]
assignedCIDR netip.Prefix
hasSubnets bool
rules string
@ -108,7 +108,7 @@ type FirewallRule struct {
Any *firewallLocalCIDR
Hosts map[string]*firewallLocalCIDR
Groups []*firewallGroups
CIDR *cidr.Tree4[*firewallLocalCIDR]
CIDR *bart.Table[*firewallLocalCIDR]
}
type firewallGroups struct {
@ -122,7 +122,7 @@ type firewallPort map[int32]*FirewallCA
type firewallLocalCIDR struct {
Any bool
LocalCIDR *cidr.Tree4[struct{}]
LocalCIDR *bart.Table[struct{}]
}
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
@ -144,20 +144,28 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
max = defaultTimeout
}
localIps := cidr.NewTree4[struct{}]()
var assignedCIDR *net.IPNet
localIps := new(bart.Table[struct{}])
var assignedCIDR netip.Prefix
var assignedSet bool
for _, ip := range c.Details.Ips {
ipNet := &net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}
localIps.AddCIDR(ipNet, struct{}{})
//TODO: IPV6-WORK the unmap is a bit unfortunate
nip, _ := netip.AddrFromSlice(ip.IP)
nip = nip.Unmap()
nprefix := netip.PrefixFrom(nip, nip.BitLen())
localIps.Insert(nprefix, struct{}{})
if assignedCIDR == nil {
if !assignedSet {
// Only grabbing the first one in the cert since any more than that currently has undefined behavior
assignedCIDR = ipNet
assignedCIDR = nprefix
assignedSet = true
}
}
for _, n := range c.Details.Subnets {
localIps.AddCIDR(n, struct{}{})
nip, _ := netip.AddrFromSlice(n.IP)
ones, _ := n.Mask.Size()
nip = nip.Unmap()
localIps.Insert(netip.PrefixFrom(nip, ones), struct{}{})
}
return &Firewall{
@ -237,15 +245,15 @@ func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *conf
}
// AddRule properly creates the in memory rule structure for a firewall table.
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
// Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
// https://github.com/golang/go/issues/14131
sIp := ""
if ip != nil {
if ip.IsValid() {
sIp = ip.String()
}
lIp := ""
if localIp != nil {
if localIp.IsValid() {
lIp = localIp.String()
}
@ -382,17 +390,17 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
}
var cidr *net.IPNet
var cidr netip.Prefix
if r.Cidr != "" {
_, cidr, err = net.ParseCIDR(r.Cidr)
cidr, err = netip.ParsePrefix(r.Cidr)
if err != nil {
return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err)
}
}
var localCidr *net.IPNet
var localCidr netip.Prefix
if r.LocalCidr != "" {
_, localCidr, err = net.ParseCIDR(r.LocalCidr)
localCidr, err = netip.ParsePrefix(r.LocalCidr)
if err != nil {
return fmt.Errorf("%s rule #%v; local_cidr did not parse; %s", table, i, err)
}
@ -421,7 +429,8 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
// Make sure remote address matches nebula certificate
if remoteCidr := h.remoteCidr; remoteCidr != nil {
ok, _ := remoteCidr.Contains(fp.RemoteIP)
//TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different
_, ok := remoteCidr.Lookup(fp.RemoteIP)
if !ok {
f.metrics(incoming).droppedRemoteIP.Inc(1)
return ErrInvalidRemoteIP
@ -435,7 +444,8 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
}
// Make sure we are supposed to be handling this local ip address
ok, _ := f.localIps.Contains(fp.LocalIP)
//TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different
_, ok := f.localIps.Lookup(fp.LocalIP)
if !ok {
f.metrics(incoming).droppedLocalIP.Inc(1)
return ErrInvalidLocalIP
@ -589,7 +599,6 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
// Caller must own the connMutex lock!
func (f *Firewall) evict(p firewall.Packet) {
//TODO: report a stat if the tcp rtt tracking was never resolved?
// Are we still tracking this conn?
conntrack := f.Conntrack
t, ok := conntrack.Conns[p]
@ -633,7 +642,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaC
return false
}
func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error {
if startPort > endPort {
return fmt.Errorf("start port was lower than end port")
}
@ -677,12 +686,12 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCer
return fp[firewall.PortAny].match(p, c, caPool)
}
func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error {
func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp netip.Prefix, caName, caSha string) error {
fr := func() *FirewallRule {
return &FirewallRule{
Hosts: make(map[string]*firewallLocalCIDR),
Groups: make([]*firewallGroups, 0),
CIDR: cidr.NewTree4[*firewallLocalCIDR](),
CIDR: new(bart.Table[*firewallLocalCIDR]),
}
}
@ -740,10 +749,10 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool
return fc.CANames[s.Details.Name].match(p, c)
}
func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *net.IPNet, localCIDR *net.IPNet) error {
func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error {
flc := func() *firewallLocalCIDR {
return &firewallLocalCIDR{
LocalCIDR: cidr.NewTree4[struct{}](),
LocalCIDR: new(bart.Table[struct{}]),
}
}
@ -780,8 +789,8 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *n
fr.Hosts[host] = nlc
}
if ip != nil {
_, nlc := fr.CIDR.GetCIDR(ip)
if ip.IsValid() {
nlc, _ := fr.CIDR.Get(ip)
if nlc == nil {
nlc = flc()
}
@ -789,14 +798,14 @@ func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *n
if err != nil {
return err
}
fr.CIDR.AddCIDR(ip, nlc)
fr.CIDR.Insert(ip, nlc)
}
return nil
}
func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool {
if len(groups) == 0 && host == "" && ip == nil {
func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) bool {
if len(groups) == 0 && host == "" && !ip.IsValid() {
return true
}
@ -810,7 +819,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool
return true
}
if ip != nil && ip.Contains(net.IPv4(0, 0, 0, 0)) {
if ip.IsValid() && ip.Bits() == 0 {
return true
}
@ -853,24 +862,31 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
}
}
return fr.CIDR.EachContains(p.RemoteIP, func(flc *firewallLocalCIDR) bool {
return flc.match(p, c)
matched := false
prefix := netip.PrefixFrom(p.RemoteIP, p.RemoteIP.BitLen())
fr.CIDR.EachLookupPrefix(prefix, func(prefix netip.Prefix, val *firewallLocalCIDR) bool {
if prefix.Contains(p.RemoteIP) && val.match(p, c) {
matched = true
return false
}
return true
})
return matched
}
func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp *net.IPNet) error {
if localIp == nil {
func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error {
if !localIp.IsValid() {
if !f.hasSubnets || f.defaultLocalCIDRAny {
flc.Any = true
return nil
}
localIp = f.assignedCIDR
} else if localIp.Contains(net.IPv4(0, 0, 0, 0)) {
} else if localIp.Bits() == 0 {
flc.Any = true
}
flc.LocalCIDR.AddCIDR(localIp, struct{}{})
flc.LocalCIDR.Insert(localIp, struct{}{})
return nil
}
@ -883,7 +899,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate
return true
}
ok, _ := flc.LocalCIDR.Contains(p.LocalIP)
_, ok := flc.LocalCIDR.Lookup(p.LocalIP)
return ok
}