diff --git a/cmd/nebula-service/main.go b/cmd/nebula-service/main.go index 9a17b947..efbcc8b8 100644 --- a/cmd/nebula-service/main.go +++ b/cmd/nebula-service/main.go @@ -78,8 +78,16 @@ func main() { } if !*configTest { - ctrl.Start() - ctrl.ShutdownBlock() + wait, err := ctrl.Start() + if err != nil { + util.LogWithContextIfNeeded("Error while running", err, l) + os.Exit(1) + } + + go ctrl.ShutdownBlock() + wait() + + l.Info("Goodbye") } os.Exit(0) diff --git a/cmd/nebula/main.go b/cmd/nebula/main.go index ffdc15bf..98913ed8 100644 --- a/cmd/nebula/main.go +++ b/cmd/nebula/main.go @@ -3,6 +3,9 @@ package main import ( "flag" "fmt" + "log" + "net/http" + _ "net/http/pprof" "os" "runtime/debug" "strings" @@ -71,10 +74,22 @@ func main() { os.Exit(1) } + go func() { + log.Println(http.ListenAndServe("0.0.0.0:6060", nil)) + }() + if !*configTest { - ctrl.Start() + wait, err := ctrl.Start() + if err != nil { + util.LogWithContextIfNeeded("Error while running", err, l) + os.Exit(1) + } + + go ctrl.ShutdownBlock() notifyReady(l) - ctrl.ShutdownBlock() + wait() + + l.Info("Goodbye") } os.Exit(0) diff --git a/control.go b/control.go index f8567b50..2d07de59 100644 --- a/control.go +++ b/control.go @@ -2,9 +2,11 @@ package nebula import ( "context" + "errors" "net/netip" "os" "os/signal" + "sync" "syscall" "github.com/sirupsen/logrus" @@ -13,6 +15,16 @@ import ( "github.com/slackhq/nebula/overlay" ) +type RunState int + +const ( + Stopped RunState = 0 // The control has yet to be started + Started RunState = 1 // The control has been started + Stopping RunState = 2 // The control is stopping +) + +var ErrAlreadyStarted = errors.New("nebula is already started") + // Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching // core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc @@ -26,6 +38,9 @@ type controlHostLister interface { } type Control struct { + stateLock sync.Mutex + state RunState + f *Interface l *logrus.Logger ctx context.Context @@ -49,10 +64,21 @@ type ControlHostInfo struct { CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"` } -// Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock() -func (c *Control) Start() { +// Start actually runs nebula, this is a nonblocking call. +// The returned function can be used to wait for nebula to fully stop. +func (c *Control) Start() (func(), error) { + c.stateLock.Lock() + if c.state != Stopped { + c.stateLock.Unlock() + return nil, ErrAlreadyStarted + } + // Activate the interface - c.f.activate() + err := c.f.activate() + if err != nil { + c.stateLock.Unlock() + return nil, err + } // Call all the delayed funcs that waited patiently for the interface to be created. if c.sshStart != nil { @@ -72,15 +98,33 @@ func (c *Control) Start() { } // Start reading packets. - c.f.run() + c.state = Started + c.stateLock.Unlock() + return c.f.run() +} + +func (c *Control) State() RunState { + c.stateLock.Lock() + defer c.stateLock.Unlock() + return c.state } func (c *Control) Context() context.Context { return c.ctx } -// Stop signals nebula to shutdown and close all tunnels, returns after the shutdown is complete +// Stop is a non-blocking call that signals nebula to close all tunnels and shut down func (c *Control) Stop() { + c.stateLock.Lock() + if c.state != Started { + c.stateLock.Unlock() + // We are stopping or stopped already + return + } + + c.state = Stopping + c.stateLock.Unlock() + // Stop the handshakeManager (and other services), to prevent new tunnels from // being created while we're shutting them all down. c.cancel() @@ -89,7 +133,7 @@ func (c *Control) Stop() { if err := c.f.Close(); err != nil { c.l.WithError(err).Error("Close interface failed") } - c.l.Info("Goodbye") + c.state = Stopped } // ShutdownBlock will listen for and block on term and interrupt signals, calling Control.Stop() once signalled diff --git a/interface.go b/interface.go index 0194a945..59b389ea 100644 --- a/interface.go +++ b/interface.go @@ -6,8 +6,8 @@ import ( "fmt" "io" "net/netip" - "os" "runtime" + "sync" "sync/atomic" "time" @@ -87,6 +87,7 @@ type Interface struct { writers []udp.Conn readers []io.ReadWriteCloser + wg sync.WaitGroup metricHandshakes metrics.Histogram messageMetrics *MessageMetrics @@ -209,7 +210,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { // activate creates the interface on the host. After the interface is created, any // other services that want to bind listeners to its IP may do so successfully. However, // the interface isn't going to process anything until run() is called. -func (f *Interface) activate() { +func (f *Interface) activate() error { // actually turn on tun dev addr, err := f.outside.LocalAddr() @@ -237,33 +238,38 @@ func (f *Interface) activate() { if i > 0 { reader, err = f.inside.NewMultiQueueReader() if err != nil { - f.l.Fatal(err) + return err } } f.readers[i] = reader } - if err := f.inside.Activate(); err != nil { + if err = f.inside.Activate(); err != nil { f.inside.Close() - f.l.Fatal(err) + return err } + + return nil } -func (f *Interface) run() { +func (f *Interface) run() (func(), error) { // Launch n queues to read packets from udp for i := 0; i < f.routines; i++ { + f.wg.Add(1) go f.listenOut(i) } // Launch n queues to read packets from tun dev for i := 0; i < f.routines; i++ { + f.wg.Add(1) go f.listenIn(f.readers[i], i) } + + return f.wg.Wait, nil } func (f *Interface) listenOut(i int) { runtime.LockOSThread() - var li udp.Conn if i > 0 { li = f.writers[i] @@ -278,14 +284,21 @@ func (f *Interface) listenOut(i int) { fwPacket := &firewall.Packet{} nb := make([]byte, 12, 12) - li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { + err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) }) + + if err != nil && !f.closed.Load() { + f.l.WithError(err).Error("Error while reading packet inbound packet, closing") + //TODO: Trigger Control to close + } + + f.l.Debugf("underlay reader %v is done", i) + f.wg.Done() } func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { runtime.LockOSThread() - packet := make([]byte, mtu) out := make([]byte, mtu) fwPacket := &firewall.Packet{} @@ -296,17 +309,18 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { for { n, err := reader.Read(packet) if err != nil { - if errors.Is(err, os.ErrClosed) && f.closed.Load() { - return + if !f.closed.Load() { + f.l.WithError(err).Error("Error while reading outbound packet, closing") + //TODO: Trigger Control to close } - - f.l.WithError(err).Error("Error while reading outbound packet") - // This only seems to happen when something fatal happens to the fd, so exit. - os.Exit(2) + break } f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get(f.l)) } + + f.l.Debugf("overlay reader %v is done", i) + f.wg.Done() } func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) { @@ -458,6 +472,7 @@ func (f *Interface) GetCertState() *CertState { func (f *Interface) Close() error { f.closed.Store(true) + // Release the udp readers for _, u := range f.writers { err := u.Close() if err != nil { @@ -465,6 +480,13 @@ func (f *Interface) Close() error { } } - // Release the tun device - return f.inside.Close() + // Release the tun readers + for _, u := range f.readers { + err := u.Close() + if err != nil { + f.l.WithError(err).Error("Error while closing tun device") + } + } + + return nil } diff --git a/main.go b/main.go index 7b326616..37d61cd4 100644 --- a/main.go +++ b/main.go @@ -291,15 +291,15 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } return &Control{ - ifce, - l, - ctx, - cancel, - sshStart, - statsStart, - dnsStart, - lightHouse.StartUpdateWorker, - connManager.Start, + f: ifce, + l: l, + ctx: ctx, + cancel: cancel, + sshStart: sshStart, + statsStart: statsStart, + dnsStart: dnsStart, + lighthouseStart: lightHouse.StartUpdateWorker, + connectionManagerStart: connManager.Start, }, nil } diff --git a/outside.go b/outside.go index 5ff87bd8..1b297c01 100644 --- a/outside.go +++ b/outside.go @@ -29,7 +29,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] return } - //l.Error("in packet ", header, packet[HeaderLen:]) + //f.l.Error("in packet ", h) if ip.IsValid() { if f.myVpnNetworksTable.Contains(ip.Addr()) { if f.l.Level >= logrus.DebugLevel { diff --git a/service/service.go b/service/service.go index fc8ac97a..c86d08c3 100644 --- a/service/service.go +++ b/service/service.go @@ -44,7 +44,10 @@ type Service struct { } func New(control *nebula.Control) (*Service, error) { - control.Start() + wait, err := control.Start() + if err != nil { + return nil, err + } ctx := control.Context() eg, ctx := errgroup.WithContext(ctx) @@ -141,6 +144,12 @@ func New(control *nebula.Control) (*Service, error) { } }) + // Add the nebula wait function to the group + eg.Go(func() error { + wait() + return nil + }) + return &s, nil } diff --git a/udp/conn.go b/udp/conn.go index 1ae585c2..30d89dec 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -16,7 +16,7 @@ type EncReader func( type Conn interface { Rebind() error LocalAddr() (netip.AddrPort, error) - ListenOut(r EncReader) + ListenOut(r EncReader) error WriteTo(b []byte, addr netip.AddrPort) error ReloadConfig(c *config.C) SupportsMultipleReaders() bool @@ -31,8 +31,8 @@ func (NoopConn) Rebind() error { func (NoopConn) LocalAddr() (netip.AddrPort, error) { return netip.AddrPort{}, nil } -func (NoopConn) ListenOut(_ EncReader) { - return +func (NoopConn) ListenOut(_ EncReader) error { + return nil } func (NoopConn) SupportsMultipleReaders() bool { return false diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index 91201194..863c98f3 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -165,7 +165,7 @@ func NewUDPStatsEmitter(udpConns []Conn) func() { return func() {} } -func (u *StdConn) ListenOut(r EncReader) { +func (u *StdConn) ListenOut(r EncReader) error { buffer := make([]byte, MTU) for { @@ -173,8 +173,7 @@ func (u *StdConn) ListenOut(r EncReader) { n, rua, err := u.ReadFromUDPAddrPort(buffer) if err != nil { if errors.Is(err, net.ErrClosed) { - u.l.WithError(err).Debug("udp socket is closed, exiting read loop") - return + return err } u.l.WithError(err).Error("unexpected udp socket receive error") diff --git a/udp/udp_generic.go b/udp/udp_generic.go index 3cefc904..44632fed 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -71,15 +71,14 @@ type rawMessage struct { Len uint32 } -func (u *GenericConn) ListenOut(r EncReader) { +func (u *GenericConn) ListenOut(r EncReader) error { buffer := make([]byte, MTU) for { // Just read one packet at a time n, rua, err := u.ReadFromUDPAddrPort(buffer) if err != nil { - u.l.WithError(err).Debug("udp socket is closed, exiting read loop") - return + return err } r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) diff --git a/udp/udp_linux.go b/udp/udp_linux.go index e7759329..32b9c69b 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -9,6 +9,7 @@ import ( "net" "net/netip" "syscall" + "time" "unsafe" "github.com/rcrowley/go-metrics" @@ -17,6 +18,8 @@ import ( "golang.org/x/sys/unix" ) +var readTimeout = unix.NsecToTimeval(int64(time.Millisecond * 500)) + type StdConn struct { sysFd int isV4 bool @@ -24,14 +27,6 @@ type StdConn struct { batch int } -func maybeIPV4(ip net.IP) (net.IP, bool) { - ip4 := ip.To4() - if ip4 != nil { - return ip4, true - } - return ip, false -} - func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { af := unix.AF_INET6 if ip.Is4() { @@ -55,6 +50,11 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in } } + // Set a read timeout + if err = unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &readTimeout); err != nil { + return nil, fmt.Errorf("unable to set SO_RCVTIMEO: %s", err) + } + var sa unix.Sockaddr if ip.Is4() { sa4 := &unix.SockaddrInet4{Port: port} @@ -122,7 +122,7 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) { } } -func (u *StdConn) ListenOut(r EncReader) { +func (u *StdConn) ListenOut(r EncReader) error { var ip netip.Addr msgs, buffers, names := u.PrepareRawMessages(u.batch) @@ -134,8 +134,7 @@ func (u *StdConn) ListenOut(r EncReader) { for { n, err := read(msgs) if err != nil { - u.l.WithError(err).Debug("udp socket is closed, exiting read loop") - return + return err } for i := 0; i < n; i++ { @@ -163,6 +162,9 @@ func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) { ) if err != 0 { + if err == unix.EAGAIN || err == unix.EINTR { + continue + } return 0, &net.OpError{Op: "recvmsg", Err: err} } @@ -184,6 +186,9 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) { ) if err != 0 { + if err == unix.EAGAIN || err == unix.EINTR { + continue + } return 0, &net.OpError{Op: "recvmmsg", Err: err} } diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go index 1d602d01..0b7ae4b7 100644 --- a/udp/udp_rio_windows.go +++ b/udp/udp_rio_windows.go @@ -134,7 +134,7 @@ func (u *RIOConn) bind(sa windows.Sockaddr) error { return nil } -func (u *RIOConn) ListenOut(r EncReader) { +func (u *RIOConn) ListenOut(r EncReader) error { buffer := make([]byte, MTU) for { @@ -142,8 +142,7 @@ func (u *RIOConn) ListenOut(r EncReader) { n, rua, err := u.receive(buffer) if err != nil { if errors.Is(err, net.ErrClosed) { - u.l.WithError(err).Debug("udp socket is closed, exiting read loop") - return + return err } u.l.WithError(err).Error("unexpected udp socket receive error") continue diff --git a/udp/udp_tester.go b/udp/udp_tester.go index 5f0f7765..5db72555 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -6,6 +6,7 @@ package udp import ( "io" "net/netip" + "os" "sync/atomic" "github.com/sirupsen/logrus" @@ -106,11 +107,11 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error { return nil } -func (u *TesterConn) ListenOut(r EncReader) { +func (u *TesterConn) ListenOut(r EncReader) error { for { p, ok := <-u.RxPackets if !ok { - return + return os.ErrClosed } r(p.From, p.Data) }