From a5b84c58e6d6509ae49effd0391d189b96263ec2 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Fri, 23 Jan 2026 14:01:10 -0600 Subject: [PATCH] plumb snataddr a little better --- firewall.go | 18 +++----- firewall_test.go | 108 +++++++++++++++++++++++------------------------ interface.go | 3 +- main.go | 3 +- 4 files changed, 63 insertions(+), 69 deletions(-) diff --git a/firewall.go b/firewall.go index 41a9e6e4..6e4461b5 100644 --- a/firewall.go +++ b/firewall.go @@ -151,7 +151,7 @@ type firewallLocalCIDR struct { // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. // The certificate provided should be the highest version loaded in memory. -func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall { +func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate, snatAddr netip.Addr) *Firewall { //TODO: error on 0 duration var tmin, tmax time.Duration @@ -185,9 +185,8 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D hasUnsafeNetworks = true } - snatAddr := netip.Addr{} - if hasUnsafeNetworks && !hasV4Networks { - snatAddr = netip.MustParseAddr("169.254.55.96") //todo this needs to come from the config, or perhaps the tun + if !hasUnsafeNetworks || hasV4Networks { + snatAddr = netip.Addr{} //disable using the special snat address if it doesn't make sense to use it } return &Firewall{ @@ -219,7 +218,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D } } -func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) { +func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C, snatAddr netip.Addr) (*Firewall, error) { certificate := cs.getCertificate(cert.Version2) if certificate == nil { certificate = cs.getCertificate(cert.Version1) @@ -229,14 +228,7 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew panic("No certificate available to reconfigure the firewall") } - fw := NewFirewall( - l, - c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12), - c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3), - c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10), - certificate, - //TODO: max_connections - ) + fw := NewFirewall(l, c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12), c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3), c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10), certificate, snatAddr) fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false) diff --git a/firewall_test.go b/firewall_test.go index 4a98acfc..6bdf7ab9 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -21,7 +21,7 @@ import ( func TestNewFirewall(t *testing.T) { l := test.NewLogger() c := &dummyCert{} - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) conntrack := fw.Conntrack assert.NotNil(t, conntrack) assert.NotNil(t, conntrack.Conns) @@ -36,23 +36,23 @@ func TestNewFirewall(t *testing.T) { assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) - fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c) + fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c, netip.Addr{}) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) - fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c) + fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c, netip.Addr{}) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) - fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c) + fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c, netip.Addr{}) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) - fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c) + fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c, netip.Addr{}) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) - fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c) + fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c, netip.Addr{}) assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration) assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen) } @@ -63,7 +63,7 @@ func TestFirewall_AddRule(t *testing.T) { l.SetOutput(ob) c := &dummyCert{} - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) assert.NotNil(t, fw.InRules) assert.NotNil(t, fw.OutRules) @@ -79,55 +79,55 @@ func TestFirewall_AddRule(t *testing.T) { assert.Empty(t, fw.InRules.TCP[1].Any.Groups) assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "")) assert.Nil(t, fw.InRules.UDP[1].Any.Any) assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1") assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", "")) assert.Nil(t, fw.InRules.ICMP[1].Any.Any) assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", "")) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti) assert.True(t, ok) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6.String(), "", "", "")) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) _, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6) assert.True(t, ok) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti.String(), "", "")) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti) assert.True(t, ok) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti6.String(), "", "")) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6) assert.True(t, ok) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "ca-name", "")) assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "ca-sha")) assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", "", "", "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) anyIp, err := netip.ParsePrefix("0.0.0.0/0") require.NoError(t, err) @@ -138,7 +138,7 @@ func TestFirewall_AddRule(t *testing.T) { table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9")) assert.False(t, ok) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) anyIp6, err := netip.ParsePrefix("::/0") require.NoError(t, err) @@ -149,28 +149,28 @@ func TestFirewall_AddRule(t *testing.T) { table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1")) assert.False(t, ok) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "any", "", "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp.String(), "", "")) assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1"))) assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9"))) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp6.String(), "", "")) assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9"))) assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1"))) - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", "any", "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) // Test error conditions - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{}) require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", "", "", "", "")) require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", "", "", "", "")) } @@ -207,7 +207,7 @@ func TestFirewall_Drop(t *testing.T) { } h.buildNetworks(myVpnNetworksTable, &c) - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) cp := cert.NewCAPool() @@ -226,27 +226,27 @@ func TestFirewall_Drop(t *testing.T) { p.RemoteAddr = oldRemote // ensure signer doesn't get in the way of group checks - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on match - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) @@ -286,7 +286,7 @@ func TestFirewall_DropV6(t *testing.T) { } h.buildNetworks(myVpnNetworksTable, &c) - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) cp := cert.NewCAPool() @@ -305,27 +305,27 @@ func TestFirewall_DropV6(t *testing.T) { p.RemoteAddr = oldRemote // ensure signer doesn't get in the way of group checks - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on match - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) @@ -531,7 +531,7 @@ func TestFirewall_Drop2(t *testing.T) { } h1.buildNetworks(myVpnNetworksTable, c1.Certificate) - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", "", "", "", "")) cp := cert.NewCAPool() @@ -611,7 +611,7 @@ func TestFirewall_Drop3(t *testing.T) { } h3.buildNetworks(myVpnNetworksTable, c3.Certificate) - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", "", "", "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "signer-sha")) cp := cert.NewCAPool() @@ -626,7 +626,7 @@ func TestFirewall_Drop3(t *testing.T) { assert.Equal(t, fw.Drop(p, nil, true, &h3, cp, nil), ErrNoMatchingRule) // Test a remote address match - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", "")) require.NoError(t, fw.Drop(p, nil, true, &h1, cp, nil)) } @@ -663,7 +663,7 @@ func TestFirewall_Drop3V6(t *testing.T) { h.buildNetworks(myVpnNetworksTable, c.Certificate) // Test a remote address match - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) cp := cert.NewCAPool() require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", "")) require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil)) @@ -703,7 +703,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { } h.buildNetworks(myVpnNetworksTable, c.Certificate) - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) cp := cert.NewCAPool() @@ -716,7 +716,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil)) oldFw := fw - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", "", "", "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 @@ -725,7 +725,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil)) oldFw = fw - fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", "", "", "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 @@ -770,7 +770,7 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) { } t.Run("ICMP allowed", func(t *testing.T) { - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 0, 0, []string{"any"}, "", "", "", "", "")) t.Run("zero ports", func(t *testing.T) { p := templ.Copy() @@ -800,7 +800,7 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) { }) t.Run("Any proto, some ports allowed", func(t *testing.T) { - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 80, 444, []string{"any"}, "", "", "", "", "")) t.Run("zero ports, still blocked", func(t *testing.T) { p := templ.Copy() @@ -842,7 +842,7 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) { }) }) t.Run("Any proto, any port", func(t *testing.T) { - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) t.Run("zero ports, allowed", func(t *testing.T) { resetConntrack(fw) @@ -907,7 +907,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) { } h1.buildNetworks(myVpnNetworksTable, c1.Certificate) - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "")) cp := cert.NewCAPool() @@ -1046,53 +1046,53 @@ func TestNewFirewallFromConfig(t *testing.T) { conf := config.NewC(l) conf.Settings["firewall"] = map[string]any{"outbound": "asdf"} - _, err = NewFirewallFromConfig(l, cs, conf) + _, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{}) require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules") // Test both port and code conf = config.NewC(l) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}} - _, err = NewFirewallFromConfig(l, cs, conf) + _, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{}) require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided") // Test missing host, group, cidr, ca_name and ca_sha conf = config.NewC(l) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}} - _, err = NewFirewallFromConfig(l, cs, conf) + _, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{}) require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided") // Test code/port error conf = config.NewC(l) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh"}}} - _, err = NewFirewallFromConfig(l, cs, conf) + _, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{}) require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh"}}} - _, err = NewFirewallFromConfig(l, cs, conf) + _, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{}) require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") // Test proto error conf = config.NewC(l) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}} - _, err = NewFirewallFromConfig(l, cs, conf) + _, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{}) require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``") // Test cidr parse error conf = config.NewC(l) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}} - _, err = NewFirewallFromConfig(l, cs, conf) + _, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{}) require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test local_cidr parse error conf = config.NewC(l) conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}} - _, err = NewFirewallFromConfig(l, cs, conf) + _, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{}) require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test both group and groups conf = config.NewC(l) conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}} - _, err = NewFirewallFromConfig(l, cs, conf) + _, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{}) require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided") } @@ -1394,7 +1394,7 @@ func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup { for _, prefix := range c.Networks() { myVpnNetworksTable.Insert(prefix) } - fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{}) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", "")) return testsetup{ diff --git a/interface.go b/interface.go index 9f83d183..53cef974 100644 --- a/interface.go +++ b/interface.go @@ -56,6 +56,7 @@ type Interface struct { inside overlay.Device pki *PKI firewall *Firewall + snatAddr netip.Addr connectionManager *connectionManager handshakeManager *HandshakeManager serveDns bool @@ -337,7 +338,7 @@ func (f *Interface) reloadFirewall(c *config.C) { return } - fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c) + fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c, f.firewall.snatAddr) if err != nil { f.l.WithError(err).Error("Error while creating firewall during reload") return diff --git a/main.go b/main.go index 5a37eacd..5098e3d9 100644 --- a/main.go +++ b/main.go @@ -66,7 +66,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err) } - fw, err := NewFirewallFromConfig(l, pki.getCertState(), c) + snatAddr := netip.MustParseAddr("169.254.55.96") //todo get this from tun! + fw, err := NewFirewallFromConfig(l, pki.getCertState(), c, snatAddr) if err != nil { return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err) }