diff --git a/firewall.go b/firewall.go index 971c156d..a349d2f9 100644 --- a/firewall.go +++ b/firewall.go @@ -417,6 +417,8 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw return nil } +var ErrUnknownNetworkType = errors.New("unknown network type") +var ErrPeerRejected = errors.New("remote IP is not within a subnet that we handle") var ErrInvalidRemoteIP = errors.New("remote IP is not in remote certificate subnets") var ErrInvalidLocalIP = errors.New("local IP is not in list of handled local IPs") var ErrNoMatchingRule = errors.New("no matching rule in firewall table") @@ -429,18 +431,31 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * return nil } - // Make sure remote address matches nebula certificate - if h.networks != nil { - if !h.networks.Contains(fp.RemoteAddr) { - f.metrics(incoming).droppedRemoteAddr.Inc(1) - return ErrInvalidRemoteIP - } - } else { + // Make sure remote address matches nebula certificate, and determine how to treat it + if h.networks == nil { // Simple case: Certificate has one address and no unsafe networks if h.vpnAddrs[0] != fp.RemoteAddr { f.metrics(incoming).droppedRemoteAddr.Inc(1) return ErrInvalidRemoteIP } + } else { + nwType, ok := h.networks.Lookup(fp.RemoteAddr) + if !ok { + f.metrics(incoming).droppedRemoteAddr.Inc(1) + return ErrInvalidRemoteIP + } + switch nwType { + case NetworkTypeVPN: + break // nothing special + case NetworkTypeVPNPeer: + f.metrics(incoming).droppedRemoteAddr.Inc(1) + return ErrPeerRejected // reject for now, one day this may have different FW rules + case NetworkTypeUnsafe: + break // nothing special, one day this may have different FW rules + default: + f.metrics(incoming).droppedRemoteAddr.Inc(1) + return ErrUnknownNetworkType //should never happen + } } // Make sure we are supposed to be handling this local ip address diff --git a/firewall_test.go b/firewall_test.go index a0cb3c88..f4630273 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -8,6 +8,8 @@ import ( "testing" "time" + "github.com/gaissmai/bart" + "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" @@ -149,7 +151,8 @@ func TestFirewall_Drop(t *testing.T) { l := test.NewLogger() ob := &bytes.Buffer{} l.SetOutput(ob) - + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) p := firewall.Packet{ LocalAddr: netip.MustParseAddr("1.2.3.4"), RemoteAddr: netip.MustParseAddr("1.2.3.4"), @@ -174,7 +177,7 @@ func TestFirewall_Drop(t *testing.T) { }, vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")}, } - h.buildNetworks(c.networks, c.unsafeNetworks) + h.buildNetworks(myVpnNetworksTable, c.networks, c.unsafeNetworks) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) @@ -226,6 +229,9 @@ func TestFirewall_DropV6(t *testing.T) { ob := &bytes.Buffer{} l.SetOutput(ob) + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7")) + p := firewall.Packet{ LocalAddr: netip.MustParseAddr("fd12::34"), RemoteAddr: netip.MustParseAddr("fd12::34"), @@ -250,7 +256,7 @@ func TestFirewall_DropV6(t *testing.T) { }, vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")}, } - h.buildNetworks(c.networks, c.unsafeNetworks) + h.buildNetworks(myVpnNetworksTable, c.networks, c.unsafeNetworks) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) @@ -453,6 +459,8 @@ func TestFirewall_Drop2(t *testing.T) { l := test.NewLogger() ob := &bytes.Buffer{} l.SetOutput(ob) + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) p := firewall.Packet{ LocalAddr: netip.MustParseAddr("1.2.3.4"), @@ -478,7 +486,7 @@ func TestFirewall_Drop2(t *testing.T) { }, vpnAddrs: []netip.Addr{network.Addr()}, } - h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks()) + h.buildNetworks(myVpnNetworksTable, c.Certificate.Networks(), c.Certificate.UnsafeNetworks()) c1 := cert.CachedCertificate{ Certificate: &dummyCert{ @@ -493,7 +501,7 @@ func TestFirewall_Drop2(t *testing.T) { peerCert: &c1, }, } - h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks()) + h1.buildNetworks(myVpnNetworksTable, c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks()) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) @@ -510,6 +518,8 @@ func TestFirewall_Drop3(t *testing.T) { l := test.NewLogger() ob := &bytes.Buffer{} l.SetOutput(ob) + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) p := firewall.Packet{ LocalAddr: netip.MustParseAddr("1.2.3.4"), @@ -541,7 +551,7 @@ func TestFirewall_Drop3(t *testing.T) { }, vpnAddrs: []netip.Addr{network.Addr()}, } - h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks()) + h1.buildNetworks(myVpnNetworksTable, c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks()) c2 := cert.CachedCertificate{ Certificate: &dummyCert{ @@ -556,7 +566,7 @@ func TestFirewall_Drop3(t *testing.T) { }, vpnAddrs: []netip.Addr{network.Addr()}, } - h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks()) + h2.buildNetworks(myVpnNetworksTable, c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks()) c3 := cert.CachedCertificate{ Certificate: &dummyCert{ @@ -571,7 +581,7 @@ func TestFirewall_Drop3(t *testing.T) { }, vpnAddrs: []netip.Addr{network.Addr()}, } - h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks()) + h3.buildNetworks(myVpnNetworksTable, c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks()) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", "")) @@ -597,6 +607,8 @@ func TestFirewall_Drop3V6(t *testing.T) { l := test.NewLogger() ob := &bytes.Buffer{} l.SetOutput(ob) + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7")) p := firewall.Packet{ LocalAddr: netip.MustParseAddr("fd12::34"), @@ -620,7 +632,7 @@ func TestFirewall_Drop3V6(t *testing.T) { }, vpnAddrs: []netip.Addr{network.Addr()}, } - h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks()) + h.buildNetworks(myVpnNetworksTable, c.Certificate.Networks(), c.Certificate.UnsafeNetworks()) // Test a remote address match fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) @@ -633,6 +645,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) { l := test.NewLogger() ob := &bytes.Buffer{} l.SetOutput(ob) + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) p := firewall.Packet{ LocalAddr: netip.MustParseAddr("1.2.3.4"), @@ -659,7 +673,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { }, vpnAddrs: []netip.Addr{network.Addr()}, } - h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks()) + h.buildNetworks(myVpnNetworksTable, c.Certificate.Networks(), c.Certificate.UnsafeNetworks()) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) @@ -696,6 +710,8 @@ func TestFirewall_DropIPSpoofing(t *testing.T) { l := test.NewLogger() ob := &bytes.Buffer{} l.SetOutput(ob) + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24")) c := cert.CachedCertificate{ Certificate: &dummyCert{ @@ -717,7 +733,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) { }, vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()}, } - h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks()) + h1.buildNetworks(myVpnNetworksTable, c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks()) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) @@ -1047,6 +1063,171 @@ func TestFirewall_convertRule(t *testing.T) { assert.Equal(t, "group1", r.Group) } +type testcase struct { + h *HostInfo + p firewall.Packet + c cert.Certificate + err error +} + +func (c *testcase) Test(t *testing.T, fw *Firewall) { + t.Helper() + cp := cert.NewCAPool() + resetConntrack(fw) + err := fw.Drop(c.p, true, c.h, cp, nil) + if c.err == nil { + require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr) + } else { + require.ErrorIs(t, c.err, err, "failed to drop remote address %s", c.p.RemoteAddr) + } +} + +func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) testcase { + c1 := dummyCert{ + name: "host1", + networks: theirPrefixes, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + h := HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &cert.CachedCertificate{ + Certificate: &c1, + InvertedGroups: map[string]struct{}{"default-group": {}}, + }, + }, + vpnAddrs: make([]netip.Addr, len(theirPrefixes)), + } + for i := range theirPrefixes { + h.vpnAddrs[i] = theirPrefixes[i].Addr() + } + h.buildNetworks(setup.myVpnNetworksTable, c1.networks, c1.unsafeNetworks) + p := firewall.Packet{ + LocalAddr: setup.c.Networks()[0].Addr(), //todo? + RemoteAddr: theirPrefixes[0].Addr(), + LocalPort: 10, + RemotePort: 90, + Protocol: firewall.ProtoUDP, + Fragment: false, + } + return testcase{ + h: &h, + p: p, + c: &c1, + err: err, + } +} + +type testsetup struct { + c dummyCert + myVpnNetworksTable *bart.Lite + fw *Firewall +} + +func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testsetup { + c := dummyCert{ + name: "me", + networks: myPrefixes, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + return newSetupFromCert(t, l, c) +} + +func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup { + myVpnNetworksTable := new(bart.Lite) + for _, prefix := range c.Networks() { + myVpnNetworksTable.Insert(prefix) + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + + return testsetup{ + c: c, + fw: fw, + myVpnNetworksTable: myVpnNetworksTable, + } +} + +func TestFirewall_Drop_EnforceIPMatch(t *testing.T) { + t.Parallel() + l := test.NewLogger() + ob := &bytes.Buffer{} + l.SetOutput(ob) + + myPrefix := netip.MustParsePrefix("1.1.1.1/8") + // for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out + t.Run("allow inbound all matching", func(t *testing.T) { + t.Parallel() + setup := newSetup(t, l, myPrefix) + tc := buildTestCase(setup, nil, netip.MustParsePrefix("1.2.3.4/24")) + tc.Test(t, setup.fw) + }) + t.Run("allow inbound local matching", func(t *testing.T) { + t.Parallel() + setup := newSetup(t, l, myPrefix) + tc := buildTestCase(setup, ErrInvalidLocalIP, netip.MustParsePrefix("1.2.3.4/24")) + tc.p.LocalAddr = netip.MustParseAddr("1.2.3.8") + tc.Test(t, setup.fw) + }) + t.Run("block inbound remote mismatched", func(t *testing.T) { + t.Parallel() + setup := newSetup(t, l, myPrefix) + tc := buildTestCase(setup, ErrInvalidRemoteIP, netip.MustParsePrefix("1.2.3.4/24")) + tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9") + tc.Test(t, setup.fw) + }) + t.Run("Block a vpn peer packet", func(t *testing.T) { + t.Parallel() + setup := newSetup(t, l, myPrefix) + tc := buildTestCase(setup, ErrPeerRejected, netip.MustParsePrefix("2.2.2.2/24")) + tc.Test(t, setup.fw) + }) + twoPrefixes := []netip.Prefix{ + netip.MustParsePrefix("1.2.3.4/24"), netip.MustParsePrefix("2.2.2.2/24"), + } + t.Run("allow inbound one matching", func(t *testing.T) { + t.Parallel() + setup := newSetup(t, l, myPrefix) + tc := buildTestCase(setup, nil, twoPrefixes...) + tc.Test(t, setup.fw) + }) + t.Run("block inbound multimismatch", func(t *testing.T) { + t.Parallel() + setup := newSetup(t, l, myPrefix) + tc := buildTestCase(setup, ErrInvalidRemoteIP, twoPrefixes...) + tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9") + tc.Test(t, setup.fw) + }) + t.Run("allow inbound 2nd one matching", func(t *testing.T) { + t.Parallel() + setup2 := newSetup(t, l, netip.MustParsePrefix("2.2.2.1/24")) + tc := buildTestCase(setup2, nil, twoPrefixes...) + tc.p.RemoteAddr = twoPrefixes[1].Addr() + tc.Test(t, setup2.fw) + }) + t.Run("allow inbound unsafe route", func(t *testing.T) { + t.Parallel() + unsafePrefix := netip.MustParsePrefix("192.168.0.0/24") + c := dummyCert{ + name: "me", + networks: []netip.Prefix{myPrefix}, + unsafeNetworks: []netip.Prefix{unsafePrefix}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + unsafeSetup := newSetupFromCert(t, l, c) + tc := buildTestCase(unsafeSetup, nil, twoPrefixes...) + tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3") + tc.err = ErrNoMatchingRule + tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off + require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, unsafePrefix, "", "")) + tc.err = nil + tc.Test(t, unsafeSetup.fw) //should pass + }) +} + type addRuleCall struct { incoming bool proto uint8 diff --git a/handshake_ix.go b/handshake_ix.go index 026bfbd4..66a78582 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -332,7 +332,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs) hostinfo.SetRemote(addr) - hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks()) + hostinfo.buildNetworks(f.myVpnNetworksTable, filteredNetworks, remoteCert.Certificate.UnsafeNetworks()) existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f) if err != nil { @@ -648,7 +648,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha // Build up the radix for the firewall if we have subnets in the cert hostinfo.vpnAddrs = vpnAddrs - hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks()) + hostinfo.buildNetworks(f.myVpnNetworksTable, filteredNetworks, remoteCert.Certificate.UnsafeNetworks()) // Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here f.handshakeManager.Complete(hostinfo, f) diff --git a/hostmap.go b/hostmap.go index 66b4851e..28270ec1 100644 --- a/hostmap.go +++ b/hostmap.go @@ -212,6 +212,18 @@ func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) { rs.relayForByIdx[idx] = r } +type NetworkType uint8 + +const ( + NetworkTypeUnknown NetworkType = iota + // NetworkTypeVPN is a network that overlaps one or more of the vpnNetworks in our certificate + NetworkTypeVPN + // NetworkTypeVPNPeer is a network that does not overlap one of our networks + NetworkTypeVPNPeer + // NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks() + NetworkTypeUnsafe +) + type HostInfo struct { remote netip.AddrPort remotes *RemoteList @@ -225,8 +237,8 @@ type HostInfo struct { // vpn networks but were removed because they are not usable vpnAddrs []netip.Addr - // networks are both all vpn and unsafe networks assigned to this host - networks *bart.Lite + // networks is a combination of specific vpn addresses (not prefixes!) and full unsafe networks assigned to this host. + networks *bart.Table[NetworkType] relayState RelayState // HandshakePacket records the packets used to create this hostinfo @@ -730,20 +742,27 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b return false } -func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) { +func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, networks, unsafeNetworks []netip.Prefix) { if len(networks) == 1 && len(unsafeNetworks) == 0 { - // Simple case, no CIDRTree needed - return + if myVpnNetworksTable.Contains(networks[0].Addr()) { + return // Simple case, no CIDRTree needed + } } - i.networks = new(bart.Lite) + i.networks = new(bart.Table[NetworkType]) for _, network := range networks { + var nwType NetworkType + if myVpnNetworksTable.Contains(network.Addr()) { + nwType = NetworkTypeVPN + } else { + nwType = NetworkTypeVPNPeer + } nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen()) - i.networks.Insert(nprefix) + i.networks.Insert(nprefix, nwType) } for _, network := range unsafeNetworks { - i.networks.Insert(network) + i.networks.Insert(network, NetworkTypeUnsafe) } }