plumb snataddr a little better

This commit is contained in:
JackDoan 2026-01-23 14:01:10 -06:00
parent 63373071cc
commit a5b84c58e6
4 changed files with 63 additions and 69 deletions

View file

@ -151,7 +151,7 @@ type firewallLocalCIDR struct {
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
// The certificate provided should be the highest version loaded in memory.
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall {
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate, snatAddr netip.Addr) *Firewall {
//TODO: error on 0 duration
var tmin, tmax time.Duration
@ -185,9 +185,8 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
hasUnsafeNetworks = true
}
snatAddr := netip.Addr{}
if hasUnsafeNetworks && !hasV4Networks {
snatAddr = netip.MustParseAddr("169.254.55.96") //todo this needs to come from the config, or perhaps the tun
if !hasUnsafeNetworks || hasV4Networks {
snatAddr = netip.Addr{} //disable using the special snat address if it doesn't make sense to use it
}
return &Firewall{
@ -219,7 +218,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
}
}
func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) {
func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C, snatAddr netip.Addr) (*Firewall, error) {
certificate := cs.getCertificate(cert.Version2)
if certificate == nil {
certificate = cs.getCertificate(cert.Version1)
@ -229,14 +228,7 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
panic("No certificate available to reconfigure the firewall")
}
fw := NewFirewall(
l,
c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
certificate,
//TODO: max_connections
)
fw := NewFirewall(l, c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12), c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3), c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10), certificate, snatAddr)
fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false)

View file

@ -21,7 +21,7 @@ import (
func TestNewFirewall(t *testing.T) {
l := test.NewLogger()
c := &dummyCert{}
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
conntrack := fw.Conntrack
assert.NotNil(t, conntrack)
assert.NotNil(t, conntrack.Conns)
@ -36,23 +36,23 @@ func TestNewFirewall(t *testing.T) {
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c)
fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c, netip.Addr{})
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c)
fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c, netip.Addr{})
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c)
fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c, netip.Addr{})
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c)
fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c, netip.Addr{})
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c)
fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c, netip.Addr{})
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
}
@ -63,7 +63,7 @@ func TestFirewall_AddRule(t *testing.T) {
l.SetOutput(ob)
c := &dummyCert{}
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
assert.NotNil(t, fw.InRules)
assert.NotNil(t, fw.OutRules)
@ -79,55 +79,55 @@ func TestFirewall_AddRule(t *testing.T) {
assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", ""))
assert.Nil(t, fw.InRules.UDP[1].Any.Any)
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", ""))
assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6.String(), "", "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti.String(), "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti6.String(), "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "ca-name", ""))
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "ca-sha"))
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", "", "", "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
anyIp, err := netip.ParsePrefix("0.0.0.0/0")
require.NoError(t, err)
@ -138,7 +138,7 @@ func TestFirewall_AddRule(t *testing.T) {
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9"))
assert.False(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
anyIp6, err := netip.ParsePrefix("::/0")
require.NoError(t, err)
@ -149,28 +149,28 @@ func TestFirewall_AddRule(t *testing.T) {
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1"))
assert.False(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "any", "", "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp.String(), "", ""))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp6.String(), "", ""))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", "any", "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
// Test error conditions
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", "", "", "", ""))
require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", "", "", "", ""))
}
@ -207,7 +207,7 @@ func TestFirewall_Drop(t *testing.T) {
}
h.buildNetworks(myVpnNetworksTable, &c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
cp := cert.NewCAPool()
@ -226,27 +226,27 @@ func TestFirewall_Drop(t *testing.T) {
p.RemoteAddr = oldRemote
// ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
@ -286,7 +286,7 @@ func TestFirewall_DropV6(t *testing.T) {
}
h.buildNetworks(myVpnNetworksTable, &c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
cp := cert.NewCAPool()
@ -305,27 +305,27 @@ func TestFirewall_DropV6(t *testing.T) {
p.RemoteAddr = oldRemote
// ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
@ -531,7 +531,7 @@ func TestFirewall_Drop2(t *testing.T) {
}
h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", "", "", "", ""))
cp := cert.NewCAPool()
@ -611,7 +611,7 @@ func TestFirewall_Drop3(t *testing.T) {
}
h3.buildNetworks(myVpnNetworksTable, c3.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", "", "", "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "signer-sha"))
cp := cert.NewCAPool()
@ -626,7 +626,7 @@ func TestFirewall_Drop3(t *testing.T) {
assert.Equal(t, fw.Drop(p, nil, true, &h3, cp, nil), ErrNoMatchingRule)
// Test a remote address match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", ""))
require.NoError(t, fw.Drop(p, nil, true, &h1, cp, nil))
}
@ -663,7 +663,7 @@ func TestFirewall_Drop3V6(t *testing.T) {
h.buildNetworks(myVpnNetworksTable, c.Certificate)
// Test a remote address match
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
cp := cert.NewCAPool()
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", ""))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
@ -703,7 +703,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
}
h.buildNetworks(myVpnNetworksTable, c.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
cp := cert.NewCAPool()
@ -716,7 +716,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil))
oldFw := fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", "", "", "", ""))
fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1
@ -725,7 +725,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil))
oldFw = fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", "", "", "", ""))
fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1
@ -770,7 +770,7 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
}
t.Run("ICMP allowed", func(t *testing.T) {
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 0, 0, []string{"any"}, "", "", "", "", ""))
t.Run("zero ports", func(t *testing.T) {
p := templ.Copy()
@ -800,7 +800,7 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
})
t.Run("Any proto, some ports allowed", func(t *testing.T) {
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 80, 444, []string{"any"}, "", "", "", "", ""))
t.Run("zero ports, still blocked", func(t *testing.T) {
p := templ.Copy()
@ -842,7 +842,7 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
})
})
t.Run("Any proto, any port", func(t *testing.T) {
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
t.Run("zero ports, allowed", func(t *testing.T) {
resetConntrack(fw)
@ -907,7 +907,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
}
h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", ""))
cp := cert.NewCAPool()
@ -1046,53 +1046,53 @@ func TestNewFirewallFromConfig(t *testing.T) {
conf := config.NewC(l)
conf.Settings["firewall"] = map[string]any{"outbound": "asdf"}
_, err = NewFirewallFromConfig(l, cs, conf)
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
// Test both port and code
conf = config.NewC(l)
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
// Test missing host, group, cidr, ca_name and ca_sha
conf = config.NewC(l)
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}}
_, err = NewFirewallFromConfig(l, cs, conf)
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
// Test code/port error
conf = config.NewC(l)
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
// Test proto error
conf = config.NewC(l)
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
// Test cidr parse error
conf = config.NewC(l)
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
// Test local_cidr parse error
conf = config.NewC(l)
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
// Test both group and groups
conf = config.NewC(l)
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
_, err = NewFirewallFromConfig(l, cs, conf)
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
}
@ -1394,7 +1394,7 @@ func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
for _, prefix := range c.Networks() {
myVpnNetworksTable.Insert(prefix)
}
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
return testsetup{

View file

@ -56,6 +56,7 @@ type Interface struct {
inside overlay.Device
pki *PKI
firewall *Firewall
snatAddr netip.Addr
connectionManager *connectionManager
handshakeManager *HandshakeManager
serveDns bool
@ -337,7 +338,7 @@ func (f *Interface) reloadFirewall(c *config.C) {
return
}
fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c)
fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c, f.firewall.snatAddr)
if err != nil {
f.l.WithError(err).Error("Error while creating firewall during reload")
return

View file

@ -66,7 +66,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
}
fw, err := NewFirewallFromConfig(l, pki.getCertState(), c)
snatAddr := netip.MustParseAddr("169.254.55.96") //todo get this from tun!
fw, err := NewFirewallFromConfig(l, pki.getCertState(), c, snatAddr)
if err != nil {
return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
}