diff --git a/examples/config.yml b/examples/config.yml index 3ba9ce4e..f81baab6 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -144,6 +144,10 @@ listen: # valid values: always, never, private # This setting is reloadable. #send_recv_error: always + # Similar to send_recv_error, this option lets you configure if you want to accept "recv_error" packets from remote hosts. + # valid values: always, never, private + # This setting is reloadable. + #accept_recv_error: always # The so_sock option is a Linux-specific feature that allows all outgoing Nebula packets to be tagged with a specific identifier. # This tagging enables IP rule-based filtering. For example, it supports 0.0.0.0/0 unsafe_routes, # allowing for more precise routing decisions based on the packet tags. Default is 0 meaning no mark is set. diff --git a/interface.go b/interface.go index 9f83d183..f69ed062 100644 --- a/interface.go +++ b/interface.go @@ -77,7 +77,8 @@ type Interface struct { reQueryEvery atomic.Uint32 reQueryWait atomic.Int64 - sendRecvErrorConfig sendRecvErrorConfig + sendRecvErrorConfig recvErrorConfig + acceptRecvErrorConfig recvErrorConfig // rebindCount is used to decide if an active tunnel should trigger a punch notification through a lighthouse rebindCount int8 @@ -110,34 +111,34 @@ type EncWriter interface { GetCertState() *CertState } -type sendRecvErrorConfig uint8 +type recvErrorConfig uint8 const ( - sendRecvErrorAlways sendRecvErrorConfig = iota - sendRecvErrorNever - sendRecvErrorPrivate + recvErrorAlways recvErrorConfig = iota + recvErrorNever + recvErrorPrivate ) -func (s sendRecvErrorConfig) ShouldSendRecvError(endpoint netip.AddrPort) bool { +func (s recvErrorConfig) ShouldRecvError(endpoint netip.AddrPort) bool { switch s { - case sendRecvErrorPrivate: + case recvErrorPrivate: return endpoint.Addr().IsPrivate() - case sendRecvErrorAlways: + case recvErrorAlways: return true - case sendRecvErrorNever: + case recvErrorNever: return false default: - panic(fmt.Errorf("invalid sendRecvErrorConfig value: %d", s)) + panic(fmt.Errorf("invalid recvErrorConfig value: %d", s)) } } -func (s sendRecvErrorConfig) String() string { +func (s recvErrorConfig) String() string { switch s { - case sendRecvErrorAlways: + case recvErrorAlways: return "always" - case sendRecvErrorNever: + case recvErrorNever: return "never" - case sendRecvErrorPrivate: + case recvErrorPrivate: return "private" default: return fmt.Sprintf("invalid(%d)", s) @@ -312,6 +313,7 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) { c.RegisterReloadCallback(f.reloadFirewall) c.RegisterReloadCallback(f.reloadSendRecvError) + c.RegisterReloadCallback(f.reloadAcceptRecvError) c.RegisterReloadCallback(f.reloadDisconnectInvalid) c.RegisterReloadCallback(f.reloadMisc) @@ -375,16 +377,16 @@ func (f *Interface) reloadSendRecvError(c *config.C) { switch stringValue { case "always": - f.sendRecvErrorConfig = sendRecvErrorAlways + f.sendRecvErrorConfig = recvErrorAlways case "never": - f.sendRecvErrorConfig = sendRecvErrorNever + f.sendRecvErrorConfig = recvErrorNever case "private": - f.sendRecvErrorConfig = sendRecvErrorPrivate + f.sendRecvErrorConfig = recvErrorPrivate default: if c.GetBool("listen.send_recv_error", true) { - f.sendRecvErrorConfig = sendRecvErrorAlways + f.sendRecvErrorConfig = recvErrorAlways } else { - f.sendRecvErrorConfig = sendRecvErrorNever + f.sendRecvErrorConfig = recvErrorNever } } @@ -393,6 +395,30 @@ func (f *Interface) reloadSendRecvError(c *config.C) { } } +func (f *Interface) reloadAcceptRecvError(c *config.C) { + if c.InitialLoad() || c.HasChanged("listen.accept_recv_error") { + stringValue := c.GetString("listen.accept_recv_error", "always") + + switch stringValue { + case "always": + f.acceptRecvErrorConfig = recvErrorAlways + case "never": + f.acceptRecvErrorConfig = recvErrorNever + case "private": + f.acceptRecvErrorConfig = recvErrorPrivate + default: + if c.GetBool("listen.accept_recv_error", true) { + f.acceptRecvErrorConfig = recvErrorAlways + } else { + f.acceptRecvErrorConfig = recvErrorNever + } + } + + f.l.WithField("acceptRecvError", f.acceptRecvErrorConfig.String()). + Info("Loaded accept_recv_error config") + } +} + func (f *Interface) reloadMisc(c *config.C) { if c.HasChanged("counters.try_promote") { n := c.GetUint32("counters.try_promote", defaultPromoteEvery) diff --git a/main.go b/main.go index 7b326616..17aaa548 100644 --- a/main.go +++ b/main.go @@ -265,6 +265,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg ifce.RegisterConfigChangeCallbacks(c) ifce.reloadDisconnectInvalid(c) ifce.reloadSendRecvError(c) + ifce.reloadAcceptRecvError(c) handshakeManager.f = ifce go handshakeManager.Run(ctx) diff --git a/outside.go b/outside.go index b1a28e57..172c3e83 100644 --- a/outside.go +++ b/outside.go @@ -516,7 +516,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out } func (f *Interface) maybeSendRecvError(endpoint netip.AddrPort, index uint32) { - if f.sendRecvErrorConfig.ShouldSendRecvError(endpoint) { + if f.sendRecvErrorConfig.ShouldRecvError(endpoint) { f.sendRecvError(endpoint, index) } } @@ -534,6 +534,13 @@ func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) { } func (f *Interface) handleRecvError(addr netip.AddrPort, h *header.H) { + if !f.acceptRecvErrorConfig.ShouldRecvError(addr) { + f.l.WithField("index", h.RemoteIndex). + WithField("udpAddr", addr). + Debug("Recv error received, ignoring") + return + } + if f.l.Level >= logrus.DebugLevel { f.l.WithField("index", h.RemoteIndex). WithField("udpAddr", addr).