diff --git a/connection_manager.go b/connection_manager.go index e7fc04cd..ee6d1eaf 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -11,7 +11,6 @@ import ( "sync/atomic" "time" - "github.com/rcrowley/go-metrics" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" @@ -45,19 +44,16 @@ type connectionManager struct { inactivityTimeout atomic.Int64 dropInactive atomic.Bool - metricsTxPunchy metrics.Counter - l *slog.Logger } func newConnectionManagerFromConfig(l *slog.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager { cm := &connectionManager{ - hostMap: hm, - l: l, - punchy: p, - relayUsed: make(map[uint32]struct{}), - relayUsedLock: &sync.RWMutex{}, - metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil), + hostMap: hm, + l: l, + punchy: p, + relayUsed: make(map[uint32]struct{}), + relayUsedLock: &sync.RWMutex{}, } cm.reload(c, true) @@ -369,7 +365,7 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim if !outTraffic { // Send a punch packet to keep the NAT state alive - cm.sendPunch(hostinfo) + cm.punchy.SendPunch(hostinfo) } return decision, hostinfo, primary @@ -400,17 +396,16 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim // If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel. // Just maintain NAT state if configured to do so. - cm.sendPunch(hostinfo) + cm.punchy.SendPunch(hostinfo) cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval) return doNothing, nil, nil } - if cm.punchy.GetTargetEverything() { - // This is similar to the old punchy behavior with a slight optimization. - // We aren't receiving traffic but we are sending it, punch on all known - // ips in case we need to re-prime NAT state - cm.sendPunch(hostinfo) - } + // We aren't receiving traffic but we are sending it. The outbound + // traffic itself refreshes the primary remote's NAT state; this + // fans out to non-primary remotes, but only if target_all_remotes + // is configured. + cm.punchy.SendPunchToAll(hostinfo) if cm.l.Enabled(context.Background(), slog.LevelDebug) { hostinfo.logger(cm.l).Debug("Tunnel status", @@ -512,31 +507,6 @@ func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostI } } -func (cm *connectionManager) sendPunch(hostinfo *HostInfo) { - if !cm.punchy.GetPunch() { - // Punching is disabled - return - } - - if cm.intf.lightHouse.IsAnyLighthouseAddr(hostinfo.vpnAddrs) { - // Do not punch to lighthouses, we assume our lighthouse update interval is good enough. - // In the event the update interval is not sufficient to maintain NAT state then a publicly available lighthouse - // would lose the ability to notify us and punchy.respond would become unreliable. - return - } - - if cm.punchy.GetTargetEverything() { - hostinfo.remotes.ForEach(cm.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) { - cm.metricsTxPunchy.Inc(1) - cm.intf.outside.WriteTo([]byte{1}, addr) - }) - - } else if hostinfo.remote.IsValid() { - cm.metricsTxPunchy.Inc(1) - cm.intf.outside.WriteTo([]byte{1}, hostinfo.remote) - } -} - func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) { cs := cm.intf.pki.getCertState() curCrt := hostinfo.ConnectionState.myCert diff --git a/connection_manager_test.go b/connection_manager_test.go index 7dc08a45..e167e5f2 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -64,7 +64,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { // Create manager conf := config.NewC(test.NewLogger()) - punchy := NewPunchyFromConfig(test.NewLogger(), conf) + punchy := NewPunchyFromConfig(test.NewLogger(), conf, nil) nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) nc.intf = ifce p := []byte("") @@ -146,7 +146,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { // Create manager conf := config.NewC(test.NewLogger()) - punchy := NewPunchyFromConfig(test.NewLogger(), conf) + punchy := NewPunchyFromConfig(test.NewLogger(), conf, nil) nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) nc.intf = ifce p := []byte("") @@ -233,7 +233,7 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) { conf.Settings["tunnels"] = map[string]any{ "drop_inactive": true, } - punchy := NewPunchyFromConfig(test.NewLogger(), conf) + punchy := NewPunchyFromConfig(test.NewLogger(), conf, nil) nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) assert.True(t, nc.dropInactive.Load()) nc.intf = ifce @@ -358,7 +358,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { // Create manager conf := config.NewC(test.NewLogger()) - punchy := NewPunchyFromConfig(test.NewLogger(), conf) + punchy := NewPunchyFromConfig(test.NewLogger(), conf, nil) nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) nc.intf = ifce ifce.connectionManager = nc diff --git a/examples/config.yml b/examples/config.yml index f5752ae4..ac4810e6 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -163,17 +163,21 @@ listen: punchy: # Continues to punch inbound/outbound at a regular interval to avoid expiration of firewall nat mappings + # This setting is reloadable. punch: true # respond means that a node you are trying to reach will connect back out to you if your hole punching fails # this is extremely useful if one node is behind a difficult nat, such as a symmetric NAT # Default is false + # This setting is reloadable. #respond: true # delays a punch response for misbehaving NATs, default is 1 second. + # This setting is reloadable. #delay: 1s # set the delay before attempting punchy.respond. Default is 5 seconds. respond must be true to take effect. + # This setting is reloadable. #respond_delay: 5s # Cipher allows you to choose between the available ciphers for your network. Options are chachapoly or aes diff --git a/lighthouse.go b/lighthouse.go index 7cce47be..1a136a1b 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -15,7 +15,6 @@ import ( "time" "github.com/gaissmai/bart" - "github.com/rcrowley/go-metrics" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" @@ -35,7 +34,6 @@ type LightHouse struct { myVpnNetworks []netip.Prefix myVpnNetworksTable *bart.Lite - punchConn udp.Conn punchy *Punchy // Local cache of answers from light houses @@ -75,9 +73,8 @@ type LightHouse struct { calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnAddr to []*calculatedRemote - metrics *MessageMetrics - metricHolepunchTx metrics.Counter - l *slog.Logger + metrics *MessageMetrics + l *slog.Logger } // NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object @@ -105,7 +102,6 @@ func NewLightHouseFromConfig(ctx context.Context, l *slog.Logger, c *config.C, c myVpnNetworksTable: cs.myVpnNetworksTable, addrMap: make(map[netip.Addr]*RemoteList), nebulaPort: nebulaPort, - punchConn: pc, punchy: p, updateTrigger: make(chan struct{}, 1), queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)), @@ -118,9 +114,6 @@ func NewLightHouseFromConfig(ctx context.Context, l *slog.Logger, c *config.C, c if c.GetBool("stats.lighthouse_metrics", false) { h.metrics = newLighthouseMetrics() - h.metricHolepunchTx = metrics.GetOrRegisterCounter("messages.tx.holepunch", nil) - } else { - h.metricHolepunchTx = metrics.NilCounter{} } err := h.reload(c, true) @@ -1406,70 +1399,25 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn return } - empty := []byte{0} - punch := func(vpnPeer netip.AddrPort, logVpnAddr netip.Addr) { - if !vpnPeer.IsValid() { - return - } - - go func() { - t := time.NewTimer(lhh.lh.punchy.GetDelay()) - defer t.Stop() - select { - case <-lhh.lh.ctx.Done(): - return - case <-t.C: - } - lhh.lh.metricHolepunchTx.Inc(1) - lhh.lh.punchConn.WriteTo(empty, vpnPeer) - }() - - if lhh.l.Enabled(context.Background(), slog.LevelDebug) { - lhh.l.Debug("Punching", - "vpnPeer", vpnPeer, - "logVpnAddr", logVpnAddr, - ) - } - } - remoteAllowList := lhh.lh.GetRemoteAllowList() for _, a := range n.Details.V4AddrPorts { b := protoV4AddrPortToNetAddrPort(a) if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) { - punch(b, detailsVpnAddr) + lhh.lh.punchy.Schedule(b, detailsVpnAddr) } } for _, a := range n.Details.V6AddrPorts { b := protoV6AddrPortToNetAddrPort(a) if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) { - punch(b, detailsVpnAddr) + lhh.lh.punchy.Schedule(b, detailsVpnAddr) } } // This sends a nebula test packet to the host trying to contact us. In the case // of a double nat or other difficult scenario, this may help establish - // a tunnel. - if lhh.lh.punchy.GetRespond() { - go func() { - t := time.NewTimer(lhh.lh.punchy.GetRespondDelay()) - defer t.Stop() - select { - case <-lhh.lh.ctx.Done(): - return - case <-t.C: - } - if lhh.l.Enabled(context.Background(), slog.LevelDebug) { - lhh.l.Debug("Sending a nebula test packet", - "vpnAddr", detailsVpnAddr, - ) - } - //NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine - // for each punchBack packet. We should move this into a timerwheel or a single goroutine - // managed by a channel. - w.SendMessageToVpnAddr(header.Test, header.TestRequest, detailsVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) - }() - } + // a tunnel. ScheduleRespond is a no-op when punchy.respond is disabled. + lhh.lh.punchy.ScheduleRespond(detailsVpnAddr) } func protoAddrToNetAddr(addr *Addr) netip.Addr { diff --git a/main.go b/main.go index 8373c44a..37aa24d1 100644 --- a/main.go +++ b/main.go @@ -170,7 +170,7 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev } hostMap := NewHostMapFromConfig(l, c) - punchy := NewPunchyFromConfig(l, c) + punchy := NewPunchyFromConfig(l, c, udpConns[0]) connManager := newConnectionManagerFromConfig(l, c, hostMap, punchy) lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy) if err != nil { @@ -240,6 +240,8 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev handshakeManager.f = ifce go handshakeManager.Run(ctx) + + punchy.Start(ctx, ifce, hostMap, lightHouse) } stats, err := newStatsServerFromConfig(ctx, l, c, buildVersion, configTest) diff --git a/punchy.go b/punchy.go index 6ecf4f85..5f0a5fa2 100644 --- a/punchy.go +++ b/punchy.go @@ -1,24 +1,72 @@ package nebula import ( + "context" "log/slog" + "net/netip" "sync/atomic" "time" + "github.com/rcrowley/go-metrics" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/udp" ) +const ( + holepunchTickDuration = 250 * time.Millisecond + holepunchWheelDuration = 60 * time.Second +) + +// holepunchJob is one scheduled item on the holepunch timer wheel. +// - target valid -> send a UDP punch to target. vpnAddr, if set, is the peer's vpn addr carried for log context. +// - target invalid, vpnAddr valid -> send an encrypted test packet to vpnAddr (a "punchback"). +type holepunchJob struct { + target netip.AddrPort + vpnAddr netip.Addr +} + +// lighthouseChecker is the slice of LightHouse that Punchy actually needs. +// Defined here so Punchy doesn't take a *LightHouse dependency (LightHouse +// already holds a *Punchy, and the bidirectional pointer reference is awkward +// even within the same package). Tests can also substitute a fake. +type lighthouseChecker interface { + IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool +} + type Punchy struct { punch atomic.Bool respond atomic.Bool delay atomic.Int64 respondDelay atomic.Int64 punchEverything atomic.Bool - l *slog.Logger + + holepunchTimer *LockingTimerWheel[holepunchJob] + punchConn udp.Conn + metricHolepunchTx metrics.Counter + metricPunchyTx metrics.Counter + + // Wired by Start, before any SendPunch* path can run. + ifce EncWriter + hm *HostMap + lh lighthouseChecker + + l *slog.Logger } -func NewPunchyFromConfig(l *slog.Logger, c *config.C) *Punchy { - p := &Punchy{l: l} +func NewPunchyFromConfig(l *slog.Logger, c *config.C, punchConn udp.Conn) *Punchy { + p := &Punchy{ + l: l, + punchConn: punchConn, + holepunchTimer: NewLockingTimerWheel[holepunchJob](holepunchTickDuration, holepunchWheelDuration), + metricPunchyTx: metrics.GetOrRegisterCounter("messages.tx.punchy", nil), + } + + if c.GetBool("stats.lighthouse_metrics", false) { + p.metricHolepunchTx = metrics.GetOrRegisterCounter("messages.tx.holepunch", nil) + } else { + p.metricHolepunchTx = metrics.NilCounter{} + } p.reload(c, true) c.RegisterReloadCallback(func(c *config.C) { @@ -29,7 +77,7 @@ func NewPunchyFromConfig(l *slog.Logger, c *config.C) *Punchy { } func (p *Punchy) reload(c *config.C, initial bool) { - if initial { + if initial || c.HasChanged("punchy.punch") || c.HasChanged("punchy") { var yes bool if c.IsSet("punchy.punch") { yes = c.GetBool("punchy.punch", false) @@ -38,16 +86,15 @@ func (p *Punchy) reload(c *config.C, initial bool) { yes = c.GetBool("punchy", false) } - p.punch.Store(yes) - if yes { + old := p.punch.Swap(yes) + switch { + case initial && yes: p.l.Info("punchy enabled") - } else { + case initial: p.l.Info("punchy disabled") + case old != yes: + p.l.Info("punchy.punch changed", "punch", yes) } - - } else if c.HasChanged("punchy.punch") || c.HasChanged("punchy") { - //TODO: it should be relatively easy to support this, just need to be able to cancel the goroutine and boot it up from here - p.l.Warn("Changing punchy.punch with reload is not supported, ignoring.") } if initial || c.HasChanged("punchy.respond") || c.HasChanged("punch_back") { @@ -59,52 +106,143 @@ func (p *Punchy) reload(c *config.C, initial bool) { yes = c.GetBool("punch_back", false) } - p.respond.Store(yes) - - if !initial { - p.l.Info("punchy.respond changed", "respond", p.GetRespond()) + old := p.respond.Swap(yes) + if !initial && old != yes { + p.l.Info("punchy.respond changed", "respond", yes) } } //NOTE: this will not apply to any in progress operations, only the next one if initial || c.HasChanged("punchy.delay") { - p.delay.Store((int64)(c.GetDuration("punchy.delay", time.Second))) - if !initial { - p.l.Info("punchy.delay changed", "delay", p.GetDelay()) + newDelay := int64(c.GetDuration("punchy.delay", time.Second)) + old := p.delay.Swap(newDelay) + if !initial && old != newDelay { + p.l.Info("punchy.delay changed", "delay", time.Duration(newDelay)) } } if initial || c.HasChanged("punchy.target_all_remotes") { - p.punchEverything.Store(c.GetBool("punchy.target_all_remotes", false)) - if !initial { - p.l.Info("punchy.target_all_remotes changed", "target_all_remotes", p.GetTargetEverything()) + yes := c.GetBool("punchy.target_all_remotes", false) + old := p.punchEverything.Swap(yes) + if !initial && old != yes { + p.l.Info("punchy.target_all_remotes changed", "target_all_remotes", yes) } } if initial || c.HasChanged("punchy.respond_delay") { - p.respondDelay.Store((int64)(c.GetDuration("punchy.respond_delay", 5*time.Second))) - if !initial { - p.l.Info("punchy.respond_delay changed", "respond_delay", p.GetRespondDelay()) + newDelay := int64(c.GetDuration("punchy.respond_delay", 5*time.Second)) + old := p.respondDelay.Swap(newDelay) + if !initial && old != newDelay { + p.l.Info("punchy.respond_delay changed", "respond_delay", time.Duration(newDelay)) } } } -func (p *Punchy) GetPunch() bool { - return p.punch.Load() +// Schedule queues a punch packet to target, to be sent after the configured delay. +// vpnAddr is the peer's vpn addr, carried through for log context when the packet actually fires. +// No-op if target is not a valid AddrPort. Safe to call from any goroutine. +func (p *Punchy) Schedule(target netip.AddrPort, vpnAddr netip.Addr) { + if !target.IsValid() { + return + } + p.holepunchTimer.Add(holepunchJob{target: target, vpnAddr: vpnAddr}, time.Duration(p.delay.Load())) } -func (p *Punchy) GetRespond() bool { - return p.respond.Load() +// ScheduleRespond queues a punchback test packet to vpnAddr after the configured respond delay, +// gated on punchy.respond. No-op when respond is disabled. +func (p *Punchy) ScheduleRespond(vpnAddr netip.Addr) { + if !p.respond.Load() { + return + } + p.holepunchTimer.Add(holepunchJob{vpnAddr: vpnAddr}, time.Duration(p.respondDelay.Load())) } -func (p *Punchy) GetDelay() time.Duration { - return (time.Duration)(p.delay.Load()) +// SendPunch sends an immediate keepalive punch for an idle hostinfo. +// The configured punchy.target_all_remotes mode picks the targets. Gated on punchy.punch and the lighthouse-skip rule +// (lighthouses don't get keepalive punches because the regular update interval keeps their NAT state warm). +func (p *Punchy) SendPunch(hostinfo *HostInfo) { + if !p.punch.Load() { + return + } + if p.lh.IsAnyLighthouseAddr(hostinfo.vpnAddrs) { + return + } + + if p.punchEverything.Load() { + p.sendPunchToAllRemotes(hostinfo) + } else if hostinfo.remote.IsValid() { + p.metricPunchyTx.Inc(1) + p.punchConn.WriteTo([]byte{1}, hostinfo.remote) + } } -func (p *Punchy) GetRespondDelay() time.Duration { - return (time.Duration)(p.respondDelay.Load()) +// SendPunchToAll punches every known remote for hostinfo, but only when punchy.target_all_remotes is enabled. +// The connection manager calls this during outbound-only traffic: the outbound traffic itself keeps the primary's +// NAT state warm, but non-primary remotes need separate refresh, so we fan out to all of them (the redundant +// primary punch is harmless). Gated on punchy.punch and the lighthouse-skip rule. +func (p *Punchy) SendPunchToAll(hostinfo *HostInfo) { + if !p.punchEverything.Load() { + return + } + if !p.punch.Load() { + return + } + if p.lh.IsAnyLighthouseAddr(hostinfo.vpnAddrs) { + return + } + p.sendPunchToAllRemotes(hostinfo) } -func (p *Punchy) GetTargetEverything() bool { - return p.punchEverything.Load() +func (p *Punchy) sendPunchToAllRemotes(hostinfo *HostInfo) { + hostinfo.remotes.ForEach(p.hm.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) { + p.metricPunchyTx.Inc(1) + p.punchConn.WriteTo([]byte{1}, addr) + }) +} + +// Start wires the runtime dependencies and runs a single goroutine that drains the holepunch timer wheel. +// Must be called after the interface is up. Exits when ctx is cancelled. +func (p *Punchy) Start(ctx context.Context, ifce EncWriter, hm *HostMap, lh lighthouseChecker) { + p.ifce = ifce + p.hm = hm + p.lh = lh + + go func() { + clockSource := time.NewTicker(holepunchTickDuration) + defer clockSource.Stop() + + nb := make([]byte, 12, 12) + out := make([]byte, mtu) + empty := []byte{0} + + for { + select { + case <-ctx.Done(): + return + case now := <-clockSource.C: + p.holepunchTimer.Advance(now) + for { + job, has := p.holepunchTimer.Purge() + if !has { + break + } + switch { + case job.target.IsValid(): + if p.l.Enabled(context.Background(), slog.LevelDebug) { + p.l.Debug("Punching", "target", job.target, "vpnAddr", job.vpnAddr) + } + p.metricHolepunchTx.Inc(1) + p.punchConn.WriteTo(empty, job.target) + case job.vpnAddr.IsValid(): + // A nebula test packet to the host trying to contact us. In the case of a double nat or other + // difficult scenario, this may help establish a tunnel. + if p.l.Enabled(context.Background(), slog.LevelDebug) { + p.l.Debug("Sending a nebula test packet", "vpnAddr", job.vpnAddr) + } + p.ifce.SendMessageToVpnAddr(header.Test, header.TestRequest, job.vpnAddr, []byte(""), nb, out) + } + } + } + } + }() } diff --git a/punchy_test.go b/punchy_test.go index cbf9b17b..e56f3eff 100644 --- a/punchy_test.go +++ b/punchy_test.go @@ -17,42 +17,42 @@ func TestNewPunchyFromConfig(t *testing.T) { c := config.NewC(l) // Test defaults - p := NewPunchyFromConfig(test.NewLogger(), c) - assert.False(t, p.GetPunch()) - assert.False(t, p.GetRespond()) - assert.Equal(t, time.Second, p.GetDelay()) - assert.Equal(t, 5*time.Second, p.GetRespondDelay()) + p := NewPunchyFromConfig(test.NewLogger(), c, nil) + assert.False(t, p.punch.Load()) + assert.False(t, p.respond.Load()) + assert.Equal(t, time.Second, time.Duration(p.delay.Load())) + assert.Equal(t, 5*time.Second, time.Duration(p.respondDelay.Load())) // punchy deprecation c.Settings["punchy"] = true - p = NewPunchyFromConfig(test.NewLogger(), c) - assert.True(t, p.GetPunch()) + p = NewPunchyFromConfig(test.NewLogger(), c, nil) + assert.True(t, p.punch.Load()) // punchy.punch c.Settings["punchy"] = map[string]any{"punch": true} - p = NewPunchyFromConfig(test.NewLogger(), c) - assert.True(t, p.GetPunch()) + p = NewPunchyFromConfig(test.NewLogger(), c, nil) + assert.True(t, p.punch.Load()) // punch_back deprecation c.Settings["punch_back"] = true - p = NewPunchyFromConfig(test.NewLogger(), c) - assert.True(t, p.GetRespond()) + p = NewPunchyFromConfig(test.NewLogger(), c, nil) + assert.True(t, p.respond.Load()) // punchy.respond c.Settings["punchy"] = map[string]any{"respond": true} c.Settings["punch_back"] = false - p = NewPunchyFromConfig(test.NewLogger(), c) - assert.True(t, p.GetRespond()) + p = NewPunchyFromConfig(test.NewLogger(), c, nil) + assert.True(t, p.respond.Load()) // punchy.delay c.Settings["punchy"] = map[string]any{"delay": "1m"} - p = NewPunchyFromConfig(test.NewLogger(), c) - assert.Equal(t, time.Minute, p.GetDelay()) + p = NewPunchyFromConfig(test.NewLogger(), c, nil) + assert.Equal(t, time.Minute, time.Duration(p.delay.Load())) // punchy.respond_delay c.Settings["punchy"] = map[string]any{"respond_delay": "1m"} - p = NewPunchyFromConfig(test.NewLogger(), c) - assert.Equal(t, time.Minute, p.GetRespondDelay()) + p = NewPunchyFromConfig(test.NewLogger(), c, nil) + assert.Equal(t, time.Minute, time.Duration(p.respondDelay.Load())) } func TestPunchy_reload(t *testing.T) { @@ -61,35 +61,34 @@ func TestPunchy_reload(t *testing.T) { delay, _ := time.ParseDuration("1m") require.NoError(t, c.LoadString(` punchy: + punch: false delay: 1m respond: false `)) - p := NewPunchyFromConfig(test.NewLogger(), c) - assert.Equal(t, delay, p.GetDelay()) - assert.False(t, p.GetRespond()) + p := NewPunchyFromConfig(test.NewLogger(), c, nil) + assert.False(t, p.punch.Load()) + assert.Equal(t, delay, time.Duration(p.delay.Load())) + assert.False(t, p.respond.Load()) newDelay, _ := time.ParseDuration("10m") require.NoError(t, c.ReloadConfigString(` punchy: + punch: true delay: 10m respond: true `)) p.reload(c, false) - assert.Equal(t, newDelay, p.GetDelay()) - assert.True(t, p.GetRespond()) + assert.True(t, p.punch.Load()) + assert.Equal(t, newDelay, time.Duration(p.delay.Load())) + assert.True(t, p.respond.Load()) } // The tests below pin the shape of each log line Punchy produces so changes // cannot silently break whatever operators are grepping for. The assertions // are on the structured message + attrs (e.g. "punchy.respond changed" with -// a respond=true field) rather than a formatted string. -// -// Punchy.reload also emits a spurious "Changing punchy.punch with reload is -// not supported" warning whenever any key under punchy changes, because of -// the c.HasChanged("punchy") fallback kept for the deprecated top-level -// punchy form. The tests filter by message rather than asserting total -// entry counts so that warning is tolerated without being locked into -// the format. +// a respond=true field) rather than a formatted string. Tests filter by +// message rather than asserting total entry counts so unrelated info lines +// are tolerated without being locked into the format. type capturedEntry struct { Level slog.Level @@ -145,7 +144,7 @@ func TestPunchy_LogFormat_InitialEnabled(t *testing.T) { c := config.NewC(test.NewLogger()) require.NoError(t, c.LoadString(`punchy: {punch: true}`)) - NewPunchyFromConfig(l, c) + NewPunchyFromConfig(l, c, nil) entry := findEntry(t, hook.entries, "punchy enabled") assert.Equal(t, slog.LevelInfo, entry.Level) @@ -157,32 +156,32 @@ func TestPunchy_LogFormat_InitialDisabled(t *testing.T) { c := config.NewC(test.NewLogger()) require.NoError(t, c.LoadString(`punchy: {punch: false}`)) - NewPunchyFromConfig(l, c) + NewPunchyFromConfig(l, c, nil) entry := findEntry(t, hook.entries, "punchy disabled") assert.Equal(t, slog.LevelInfo, entry.Level) assert.Empty(t, entry.Attrs) } -func TestPunchy_LogFormat_ReloadPunchUnsupported(t *testing.T) { +func TestPunchy_LogFormat_ReloadPunch(t *testing.T) { l, hook := newCapturingPunchyLogger(t) c := config.NewC(test.NewLogger()) require.NoError(t, c.LoadString(`punchy: {punch: false}`)) - NewPunchyFromConfig(l, c) + NewPunchyFromConfig(l, c, nil) hook.entries = nil require.NoError(t, c.ReloadConfigString(`punchy: {punch: true}`)) - entry := findEntry(t, hook.entries, "Changing punchy.punch with reload is not supported, ignoring.") - assert.Equal(t, slog.LevelWarn, entry.Level) - assert.Empty(t, entry.Attrs) + entry := findEntry(t, hook.entries, "punchy.punch changed") + assert.Equal(t, slog.LevelInfo, entry.Level) + assert.Equal(t, map[string]any{"punch": true}, entry.Attrs) } func TestPunchy_LogFormat_ReloadRespond(t *testing.T) { l, hook := newCapturingPunchyLogger(t) c := config.NewC(test.NewLogger()) require.NoError(t, c.LoadString(`punchy: {respond: false}`)) - NewPunchyFromConfig(l, c) + NewPunchyFromConfig(l, c, nil) hook.entries = nil require.NoError(t, c.ReloadConfigString(`punchy: {respond: true}`)) @@ -196,7 +195,7 @@ func TestPunchy_LogFormat_ReloadDelay(t *testing.T) { l, hook := newCapturingPunchyLogger(t) c := config.NewC(test.NewLogger()) require.NoError(t, c.LoadString(`punchy: {delay: 1s}`)) - NewPunchyFromConfig(l, c) + NewPunchyFromConfig(l, c, nil) hook.entries = nil require.NoError(t, c.ReloadConfigString(`punchy: {delay: 10s}`)) @@ -210,7 +209,7 @@ func TestPunchy_LogFormat_ReloadTargetAllRemotes(t *testing.T) { l, hook := newCapturingPunchyLogger(t) c := config.NewC(test.NewLogger()) require.NoError(t, c.LoadString(`punchy: {target_all_remotes: false}`)) - NewPunchyFromConfig(l, c) + NewPunchyFromConfig(l, c, nil) hook.entries = nil require.NoError(t, c.ReloadConfigString(`punchy: {target_all_remotes: true}`)) @@ -224,7 +223,7 @@ func TestPunchy_LogFormat_ReloadRespondDelay(t *testing.T) { l, hook := newCapturingPunchyLogger(t) c := config.NewC(test.NewLogger()) require.NoError(t, c.LoadString(`punchy: {respond_delay: 5s}`)) - NewPunchyFromConfig(l, c) + NewPunchyFromConfig(l, c, nil) hook.entries = nil require.NoError(t, c.ReloadConfigString(`punchy: {respond_delay: 15s}`))